Merge remote-tracking branch 'origin/main' into feat/searxng

This commit is contained in:
Vinh Tran
2026-02-26 08:21:09 +01:00
231 changed files with 14763 additions and 3383 deletions
+5 -5
View File
@@ -6,8 +6,8 @@ import "context"
type Tool interface {
Name() string
Description() string
Parameters() map[string]interface{}
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
Parameters() map[string]any
Execute(ctx context.Context, args map[string]any) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
@@ -69,10 +69,10 @@ type AsyncTool interface {
SetCallback(cb AsyncCallback)
}
func ToolToSchema(tool Tool) map[string]interface{} {
return map[string]interface{}{
func ToolToSchema(tool Tool) map[string]any {
return map[string]any{
"type": "function",
"function": map[string]interface{}{
"function": map[string]any{
"name": tool.Name(),
"description": tool.Description(),
"parameters": tool.Parameters(),
+20 -18
View File
@@ -30,7 +30,10 @@ type CronTool struct {
// NewCronTool creates a new CronTool
// execTimeout: 0 means no timeout, >0 sets the timeout duration
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *CronTool {
func NewCronTool(
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
execTimeout time.Duration, config *config.Config,
) *CronTool {
execTool := NewExecToolWithConfig(workspace, restrict, config)
execTool.SetTimeout(execTimeout)
return &CronTool{
@@ -52,40 +55,40 @@ func (t *CronTool) Description() string {
}
// Parameters returns the tool parameters schema
func (t *CronTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *CronTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"properties": map[string]any{
"action": map[string]any{
"type": "string",
"enum": []string{"add", "list", "remove", "enable", "disable"},
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
},
"message": map[string]interface{}{
"message": map[string]any{
"type": "string",
"description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.",
},
"command": map[string]interface{}{
"command": map[string]any{
"type": "string",
"description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.",
},
"at_seconds": map[string]interface{}{
"at_seconds": map[string]any{
"type": "integer",
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
},
"every_seconds": map[string]interface{}{
"every_seconds": map[string]any{
"type": "integer",
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
},
"cron_expr": map[string]interface{}{
"cron_expr": map[string]any{
"type": "string",
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
},
"job_id": map[string]interface{}{
"job_id": map[string]any{
"type": "string",
"description": "Job ID (for remove/enable/disable)",
},
"deliver": map[string]interface{}{
"deliver": map[string]any{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
},
@@ -103,7 +106,7 @@ func (t *CronTool) SetContext(channel, chatID string) {
}
// Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
action, ok := args["action"].(string)
if !ok {
return ErrorResult("action is required")
@@ -125,7 +128,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *To
}
}
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
func (t *CronTool) addJob(args map[string]any) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
@@ -233,7 +236,7 @@ func (t *CronTool) listJobs() *ToolResult {
return SilentResult(result)
}
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return ErrorResult("job_id is required for remove")
@@ -245,7 +248,7 @@ func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
func (t *CronTool) enableJob(args map[string]any, enable bool) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return ErrorResult("job_id is required for enable/disable")
@@ -279,7 +282,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// Execute command if present
if job.Payload.Command != "" {
args := map[string]interface{}{
args := map[string]any{
"command": job.Payload.Command,
}
@@ -320,7 +323,6 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
channel,
chatID,
)
if err != nil {
return fmt.Sprintf("Error: %v", err)
}
+77 -65
View File
@@ -2,24 +2,27 @@ package tools
import (
"context"
"errors"
"fmt"
"os"
"io/fs"
"strings"
)
// EditFileTool edits a file by replacing old_text with new_text.
// The old_text must exist exactly in the file.
type EditFileTool struct {
allowedDir string
restrict bool
fs fileSystem
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool {
return &EditFileTool{
allowedDir: allowedDir,
restrict: restrict,
func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
}
return &EditFileTool{fs: fs}
}
func (t *EditFileTool) Name() string {
@@ -30,19 +33,19 @@ func (t *EditFileTool) Description() string {
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
}
func (t *EditFileTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *EditFileTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"path": map[string]interface{}{
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "The file path to edit",
},
"old_text": map[string]interface{}{
"old_text": map[string]any{
"type": "string",
"description": "The exact text to find and replace",
},
"new_text": map[string]interface{}{
"new_text": map[string]any{
"type": "string",
"description": "The text to replace with",
},
@@ -51,7 +54,7 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
}
}
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *EditFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
@@ -67,47 +70,24 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{})
return ErrorResult("new_text is required")
}
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
if err != nil {
if err := editFile(t.fs, path, oldText, newText); err != nil {
return ErrorResult(err.Error())
}
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
return ErrorResult(fmt.Sprintf("file not found: %s", path))
}
content, err := os.ReadFile(resolvedPath)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
contentStr := string(content)
if !strings.Contains(contentStr, oldText) {
return ErrorResult("old_text not found in file. Make sure it matches exactly")
}
count := strings.Count(contentStr, oldText)
if count > 1 {
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
}
newContent := strings.Replace(contentStr, oldText, newText, 1)
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return SilentResult(fmt.Sprintf("File edited: %s", path))
}
type AppendFileTool struct {
workspace string
restrict bool
fs fileSystem
}
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
return &AppendFileTool{workspace: workspace, restrict: restrict}
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
}
return &AppendFileTool{fs: fs}
}
func (t *AppendFileTool) Name() string {
@@ -118,15 +98,15 @@ func (t *AppendFileTool) Description() string {
return "Append content to the end of a file"
}
func (t *AppendFileTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *AppendFileTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"path": map[string]interface{}{
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "The file path to append to",
},
"content": map[string]interface{}{
"content": map[string]any{
"type": "string",
"description": "The content to append",
},
@@ -135,7 +115,7 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
}
}
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
@@ -146,20 +126,52 @@ func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{
return ErrorResult("content is required")
}
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
if err := appendFile(t.fs, path, content); err != nil {
return ErrorResult(err.Error())
}
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
}
defer f.Close()
if _, err := f.WriteString(content); err != nil {
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
}
return SilentResult(fmt.Sprintf("Appended to %s", path))
}
// editFile reads the file via sysFs, performs the replacement, and writes back.
// It uses a fileSystem interface, allowing the same logic for both restricted and unrestricted modes.
func editFile(sysFs fileSystem, path, oldText, newText string) error {
content, err := sysFs.ReadFile(path)
if err != nil {
return err
}
newContent, err := replaceEditContent(content, oldText, newText)
if err != nil {
return err
}
return sysFs.WriteFile(path, newContent)
}
// appendFile reads the existing content (if any) via sysFs, appends new content, and writes back.
func appendFile(sysFs fileSystem, path, appendContent string) error {
content, err := sysFs.ReadFile(path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return err
}
newContent := append(content, []byte(appendContent)...)
return sysFs.WriteFile(path, newContent)
}
// replaceEditContent handles the core logic of finding and replacing a single occurrence of oldText.
func replaceEditContent(content []byte, oldText, newText string) ([]byte, error) {
contentStr := string(content)
if !strings.Contains(contentStr, oldText) {
return nil, fmt.Errorf("old_text not found in file. Make sure it matches exactly")
}
count := strings.Count(contentStr, oldText)
if count > 1 {
return nil, fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
}
newContent := strings.Replace(contentStr, oldText, newText, 1)
return []byte(newContent), nil
}
+170 -22
View File
@@ -6,17 +6,19 @@ import (
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// TestEditTool_EditFile_Success verifies successful file editing
func TestEditTool_EditFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0o644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"old_text": "World",
"new_text": "Universe",
@@ -60,7 +62,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"old_text": "old",
"new_text": "new",
@@ -83,11 +85,11 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World"), 0644)
os.WriteFile(testFile, []byte("Hello World"), 0o644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"old_text": "Goodbye",
"new_text": "Hello",
@@ -110,11 +112,11 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test test test"), 0644)
os.WriteFile(testFile, []byte("test test test"), 0o644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"old_text": "test",
"new_text": "done",
@@ -138,11 +140,11 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
tmpDir := t.TempDir()
otherDir := t.TempDir()
testFile := filepath.Join(otherDir, "test.txt")
os.WriteFile(testFile, []byte("content"), 0644)
os.WriteFile(testFile, []byte("content"), 0o644)
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"old_text": "content",
"new_text": "new",
@@ -151,21 +153,25 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is outside allowed directory")
}
assert.True(t, result.IsError, "Expected error when path is outside allowed directory")
// Should mention outside allowed directory
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
}
// Note: ErrorResult only sets ForLLM by default, so ForUser might be empty.
// We check ForLLM as it's the primary error channel.
assert.True(
t,
strings.Contains(result.ForLLM, "outside") || strings.Contains(result.ForLLM, "access denied") ||
strings.Contains(result.ForLLM, "escapes"),
"Expected 'outside allowed' or 'access denied' message, got ForLLM: %s",
result.ForLLM,
)
}
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
func TestEditTool_EditFile_MissingPath(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"old_text": "old",
"new_text": "new",
}
@@ -182,7 +188,7 @@ func TestEditTool_EditFile_MissingPath(t *testing.T) {
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": "/tmp/test.txt",
"new_text": "new",
}
@@ -199,7 +205,7 @@ func TestEditTool_EditFile_MissingOldText(t *testing.T) {
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": "/tmp/test.txt",
"old_text": "old",
}
@@ -216,11 +222,11 @@ func TestEditTool_EditFile_MissingNewText(t *testing.T) {
func TestEditTool_AppendFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Initial content"), 0644)
os.WriteFile(testFile, []byte("Initial content"), 0o644)
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"content": "\nAppended content",
}
@@ -260,7 +266,7 @@ func TestEditTool_AppendFile_Success(t *testing.T) {
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"content": "test",
}
@@ -276,7 +282,7 @@ func TestEditTool_AppendFile_MissingPath(t *testing.T) {
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": "/tmp/test.txt",
}
@@ -287,3 +293,145 @@ func TestEditTool_AppendFile_MissingContent(t *testing.T) {
t.Errorf("Expected error when content is missing")
}
}
// TestReplaceEditContent verifies the helper function replaceEditContent
func TestReplaceEditContent(t *testing.T) {
tests := []struct {
name string
content []byte
oldText string
newText string
expected []byte
expectError bool
}{
{
name: "successful replacement",
content: []byte("hello world"),
oldText: "world",
newText: "universe",
expected: []byte("hello universe"),
expectError: false,
},
{
name: "old text not found",
content: []byte("hello world"),
oldText: "golang",
newText: "rust",
expected: nil,
expectError: true,
},
{
name: "multiple matches found",
content: []byte("test text test"),
oldText: "test",
newText: "done",
expected: nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := replaceEditContent(tt.content, tt.oldText, tt.newText)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
// TestAppendFileTool_AppendToNonExistent_Restricted verifies that AppendFileTool in restricted mode
// can append to a file that does not yet exist — it should silently create the file.
// This exercises the errors.Is(err, fs.ErrNotExist) path in appendFileWithRW + rootRW.
func TestAppendFileTool_AppendToNonExistent_Restricted(t *testing.T) {
workspace := t.TempDir()
tool := NewAppendFileTool(workspace, true)
ctx := context.Background()
args := map[string]any{
"path": "brand_new_file.txt",
"content": "first content",
}
result := tool.Execute(ctx, args)
assert.False(
t,
result.IsError,
"Expected success when appending to non-existent file in restricted mode, got: %s",
result.ForLLM,
)
// Verify the file was created with correct content
data, err := os.ReadFile(filepath.Join(workspace, "brand_new_file.txt"))
assert.NoError(t, err)
assert.Equal(t, "first content", string(data))
}
// TestAppendFileTool_Restricted_Success verifies that AppendFileTool in restricted mode
// correctly appends to an existing file within the sandbox.
func TestAppendFileTool_Restricted_Success(t *testing.T) {
workspace := t.TempDir()
testFile := "existing.txt"
err := os.WriteFile(filepath.Join(workspace, testFile), []byte("initial"), 0o644)
assert.NoError(t, err)
tool := NewAppendFileTool(workspace, true)
ctx := context.Background()
args := map[string]any{
"path": testFile,
"content": " appended",
}
result := tool.Execute(ctx, args)
assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
assert.True(t, result.Silent)
data, err := os.ReadFile(filepath.Join(workspace, testFile))
assert.NoError(t, err)
assert.Equal(t, "initial appended", string(data))
}
// TestEditFileTool_Restricted_InPlaceEdit verifies that EditFileTool in restricted mode
// correctly edits a file using the single-open editFileInRoot path.
func TestEditFileTool_Restricted_InPlaceEdit(t *testing.T) {
workspace := t.TempDir()
testFile := "edit_target.txt"
err := os.WriteFile(filepath.Join(workspace, testFile), []byte("Hello World"), 0o644)
assert.NoError(t, err)
tool := NewEditFileTool(workspace, true)
ctx := context.Background()
args := map[string]any{
"path": testFile,
"old_text": "World",
"new_text": "Go",
}
result := tool.Execute(ctx, args)
assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
assert.True(t, result.Silent)
data, err := os.ReadFile(filepath.Join(workspace, testFile))
assert.NoError(t, err)
assert.Equal(t, "Hello Go", string(data))
}
// TestEditFileTool_Restricted_FileNotFound verifies that editFileInRoot returns a proper
// error message when the target file does not exist.
func TestEditFileTool_Restricted_FileNotFound(t *testing.T) {
workspace := t.TempDir()
tool := NewEditFileTool(workspace, true)
ctx := context.Background()
args := map[string]any{
"path": "no_such_file.txt",
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "not found")
}
+216 -58
View File
@@ -3,15 +3,17 @@ package tools
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
)
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
if workspace == "" {
return path, nil
return path, fmt.Errorf("workspace is not defined")
}
absWorkspace, err := filepath.Abs(workspace)
@@ -34,17 +36,19 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
var resolved string
workspaceReal := absWorkspace
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
if resolved, err = filepath.EvalSymlinks(absWorkspace); err == nil {
workspaceReal = resolved
}
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
if resolved, err = filepath.EvalSymlinks(absPath); err == nil {
if !isWithinWorkspace(resolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if os.IsNotExist(err) {
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
var parentResolved string
if parentResolved, err = resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
if !isWithinWorkspace(parentResolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
@@ -74,16 +78,21 @@ func resolveExistingAncestor(path string) (string, error) {
func isWithinWorkspace(candidate, workspace string) bool {
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
return err == nil && filepath.IsLocal(rel)
}
type ReadFileTool struct {
workspace string
restrict bool
fs fileSystem
}
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
return &ReadFileTool{workspace: workspace, restrict: restrict}
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
}
return &ReadFileTool{fs: fs}
}
func (t *ReadFileTool) Name() string {
@@ -94,11 +103,11 @@ func (t *ReadFileTool) Description() string {
return "Read the contents of a file"
}
func (t *ReadFileTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *ReadFileTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"path": map[string]interface{}{
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to read",
},
@@ -107,32 +116,31 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
}
}
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
}
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
content, err := t.fs.ReadFile(path)
if err != nil {
return ErrorResult(err.Error())
}
content, err := os.ReadFile(resolvedPath)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
return NewToolResult(string(content))
}
type WriteFileTool struct {
workspace string
restrict bool
fs fileSystem
}
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
return &WriteFileTool{workspace: workspace, restrict: restrict}
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
}
return &WriteFileTool{fs: fs}
}
func (t *WriteFileTool) Name() string {
@@ -143,15 +151,15 @@ func (t *WriteFileTool) Description() string {
return "Write content to a file"
}
func (t *WriteFileTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *WriteFileTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"path": map[string]interface{}{
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to write",
},
"content": map[string]interface{}{
"content": map[string]any{
"type": "string",
"description": "Content to write to the file",
},
@@ -160,7 +168,7 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
}
}
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
@@ -171,30 +179,25 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}
return ErrorResult("content is required")
}
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
if err := t.fs.WriteFile(path, []byte(content)); err != nil {
return ErrorResult(err.Error())
}
dir := filepath.Dir(resolvedPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
}
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return SilentResult(fmt.Sprintf("File written: %s", path))
}
type ListDirTool struct {
workspace string
restrict bool
fs fileSystem
}
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
return &ListDirTool{workspace: workspace, restrict: restrict}
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
}
return &ListDirTool{fs: fs}
}
func (t *ListDirTool) Name() string {
@@ -205,11 +208,11 @@ func (t *ListDirTool) Description() string {
return "List files and directories in a path"
}
func (t *ListDirTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *ListDirTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"path": map[string]interface{}{
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to list",
},
@@ -218,30 +221,185 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
}
}
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *ListDirTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
path = "."
}
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return ErrorResult(err.Error())
}
entries, err := os.ReadDir(resolvedPath)
entries, err := t.fs.ReadDir(path)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
}
return formatDirEntries(entries)
}
result := ""
func formatDirEntries(entries []os.DirEntry) *ToolResult {
var result strings.Builder
for _, entry := range entries {
if entry.IsDir() {
result += "DIR: " + entry.Name() + "\n"
result.WriteString("DIR: " + entry.Name() + "\n")
} else {
result += "FILE: " + entry.Name() + "\n"
result.WriteString("FILE: " + entry.Name() + "\n")
}
}
return NewToolResult(result.String())
}
// fileSystem abstracts reading, writing, and listing files, allowing both
// unrestricted (host filesystem) and sandbox (os.Root) implementations to share the same polymorphic interface.
type fileSystem interface {
ReadFile(path string) ([]byte, error)
WriteFile(path string, data []byte) error
ReadDir(path string) ([]os.DirEntry, error)
}
// hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem.
type hostFs struct{}
func (h *hostFs) ReadFile(path string) ([]byte, error) {
content, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("failed to read file: file not found: %w", err)
}
if os.IsPermission(err) {
return nil, fmt.Errorf("failed to read file: access denied: %w", err)
}
return nil, fmt.Errorf("failed to read file: %w", err)
}
return content, nil
}
func (h *hostFs) ReadDir(path string) ([]os.DirEntry, error) {
return os.ReadDir(path)
}
func (h *hostFs) WriteFile(path string, data []byte) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("failed to create parent directories: %w", err)
}
// We use a "write-then-rename" pattern here to ensure an atomic write.
// This prevents the target file from being left in a truncated or partial state
// if the operation is interrupted, as the rename operation is atomic on Linux.
tmpPath := fmt.Sprintf("%s.%d.tmp", path, time.Now().UnixNano())
if err := os.WriteFile(tmpPath, data, 0o644); err != nil {
os.Remove(tmpPath) // Ensure cleanup of partial/empty temp file
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
os.Remove(tmpPath)
return fmt.Errorf("failed to replace original file: %w", err)
}
return nil
}
// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root.
type sandboxFs struct {
workspace string
}
func (r *sandboxFs) execute(path string, fn func(root *os.Root, relPath string) error) error {
if r.workspace == "" {
return fmt.Errorf("workspace is not defined")
}
root, err := os.OpenRoot(r.workspace)
if err != nil {
return fmt.Errorf("failed to open workspace: %w", err)
}
defer root.Close()
relPath, err := getSafeRelPath(r.workspace, path)
if err != nil {
return err
}
return fn(root, relPath)
}
func (r *sandboxFs) ReadFile(path string) ([]byte, error) {
var content []byte
err := r.execute(path, func(root *os.Root, relPath string) error {
fileContent, err := root.ReadFile(relPath)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("failed to read file: file not found: %w", err)
}
// os.Root returns "escapes from parent" for paths outside the root
if os.IsPermission(err) || strings.Contains(err.Error(), "escapes from parent") ||
strings.Contains(err.Error(), "permission denied") {
return fmt.Errorf("failed to read file: access denied: %w", err)
}
return fmt.Errorf("failed to read file: %w", err)
}
content = fileContent
return nil
})
return content, err
}
func (r *sandboxFs) WriteFile(path string, data []byte) error {
return r.execute(path, func(root *os.Root, relPath string) error {
dir := filepath.Dir(relPath)
if dir != "." && dir != "/" {
if err := root.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("failed to create parent directories: %w", err)
}
}
// We use a "write-then-rename" pattern here to ensure an atomic write.
// This prevents the target file from being left in a truncated or partial state
// if the operation is interrupted, as the rename operation is atomic on Linux.
tmpRelPath := fmt.Sprintf("%s.%d.tmp", relPath, time.Now().UnixNano())
if err := root.WriteFile(tmpRelPath, data, 0o644); err != nil {
root.Remove(tmpRelPath) // Ensure cleanup of partial/empty temp file
return fmt.Errorf("failed to write to temp file: %w", err)
}
if err := root.Rename(tmpRelPath, relPath); err != nil {
root.Remove(tmpRelPath)
return fmt.Errorf("failed to rename temp file over target: %w", err)
}
return nil
})
}
func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
var entries []os.DirEntry
err := r.execute(path, func(root *os.Root, relPath string) error {
dirEntries, err := fs.ReadDir(root.FS(), relPath)
if err != nil {
return err
}
entries = dirEntries
return nil
})
return entries, err
}
// Helper to get a safe relative path for os.Root usage
func getSafeRelPath(workspace, path string) (string, error) {
if workspace == "" {
return "", fmt.Errorf("workspace is not defined")
}
rel := filepath.Clean(path)
if filepath.IsAbs(rel) {
var err error
rel, err = filepath.Rel(workspace, rel)
if err != nil {
return "", fmt.Errorf("failed to calculate relative path: %w", err)
}
}
return NewToolResult(result)
if !filepath.IsLocal(rel) {
return "", fmt.Errorf("path escapes workspace: %s", path)
}
return rel, nil
}
+236 -29
View File
@@ -2,21 +2,24 @@ package tools
import (
"context"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// TestFilesystemTool_ReadFile_Success verifies successful file reading
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
os.WriteFile(testFile, []byte("test content"), 0o644)
tool := &ReadFileTool{}
tool := NewReadFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
}
@@ -41,9 +44,9 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
tool := &ReadFileTool{}
tool := NewReadFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": "/nonexistent_file_12345.txt",
}
@@ -64,7 +67,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{}
args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -84,9 +87,9 @@ func TestFilesystemTool_WriteFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "newfile.txt")
tool := &WriteFileTool{}
tool := NewWriteFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"content": "hello world",
}
@@ -123,9 +126,9 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
tool := &WriteFileTool{}
tool := NewWriteFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": testFile,
"content": "test",
}
@@ -149,9 +152,9 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
tool := &WriteFileTool{}
tool := NewWriteFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"content": "test",
}
@@ -165,9 +168,9 @@ func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
tool := &WriteFileTool{}
tool := NewWriteFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": "/tmp/test.txt",
}
@@ -179,7 +182,8 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
if !strings.Contains(result.ForLLM, "content is required") &&
!strings.Contains(result.ForUser, "content is required") {
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
}
}
@@ -187,13 +191,13 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
// TestFilesystemTool_ListDir_Success verifies successful directory listing
func TestFilesystemTool_ListDir_Success(t *testing.T) {
tmpDir := t.TempDir()
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0o644)
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644)
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755)
tool := &ListDirTool{}
tool := NewListDirTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": tmpDir,
}
@@ -215,9 +219,9 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
tool := &ListDirTool{}
tool := NewListDirTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"path": "/nonexistent_directory_12345",
}
@@ -236,9 +240,9 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
tool := &ListDirTool{}
tool := NewListDirTool("", false)
ctx := context.Background()
args := map[string]interface{}{}
args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -250,15 +254,14 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
// Block paths that look inside workspace but point outside via symlink.
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
root := t.TempDir()
workspace := filepath.Join(root, "workspace")
if err := os.MkdirAll(workspace, 0755); err != nil {
if err := os.MkdirAll(workspace, 0o755); err != nil {
t.Fatalf("failed to create workspace: %v", err)
}
secret := filepath.Join(root, "secret.txt")
if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
if err := os.WriteFile(secret, []byte("top secret"), 0o644); err != nil {
t.Fatalf("failed to write secret file: %v", err)
}
@@ -268,14 +271,218 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
}
tool := NewReadFileTool(workspace, true)
result := tool.Execute(context.Background(), map[string]interface{}{
result := tool.Execute(context.Background(), map[string]any{
"path": link,
})
if !result.IsError {
t.Fatalf("expected symlink escape to be blocked")
}
if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
// os.Root might return different errors depending on platform/implementation
// but it definitely should error.
// Our wrapper returns "access denied or file not found"
if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
!strings.Contains(result.ForLLM, "no such file") {
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
}
}
func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
tool := NewReadFileTool("", true) // restrict=true but workspace=""
// Try to read a sensitive file (simulated by a temp file outside workspace)
tmpDir := t.TempDir()
secretFile := filepath.Join(tmpDir, "shadow")
os.WriteFile(secretFile, []byte("secret data"), 0o600)
result := tool.Execute(context.Background(), map[string]any{
"path": secretFile,
})
// We EXPECT IsError=true (access blocked due to empty workspace)
assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
// Verify it failed for the right reason
assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
}
// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
// single dir, deeply nested dirs, already-existing dirs, and a file blocking a directory path.
func TestRootMkdirAll(t *testing.T) {
workspace := t.TempDir()
root, err := os.OpenRoot(workspace)
if err != nil {
t.Fatalf("failed to open root: %v", err)
}
defer root.Close()
// Case 1: Single directory
err = root.MkdirAll("dir1", 0o755)
assert.NoError(t, err)
_, err = os.Stat(filepath.Join(workspace, "dir1"))
assert.NoError(t, err)
// Case 2: Deeply nested directory
err = root.MkdirAll("a/b/c/d", 0o755)
assert.NoError(t, err)
_, err = os.Stat(filepath.Join(workspace, "a/b/c/d"))
assert.NoError(t, err)
// Case 3: Already exists — must be idempotent
err = root.MkdirAll("a/b/c/d", 0o755)
assert.NoError(t, err)
// Case 4: A regular file blocks directory creation — must error
err = os.WriteFile(filepath.Join(workspace, "file_exists"), []byte("data"), 0o644)
assert.NoError(t, err)
err = root.MkdirAll("file_exists", 0o755)
assert.Error(t, err, "expected error when a file exists at the directory path")
}
func TestFilesystemTool_WriteFile_Restricted_CreateDir(t *testing.T) {
workspace := t.TempDir()
tool := NewWriteFileTool(workspace, true)
ctx := context.Background()
testFile := "deep/nested/path/to/file.txt"
content := "deep content"
args := map[string]any{
"path": testFile,
"content": content,
}
result := tool.Execute(ctx, args)
assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
// Verify file content
actualPath := filepath.Join(workspace, testFile)
data, err := os.ReadFile(actualPath)
assert.NoError(t, err)
assert.Equal(t, content, string(data))
}
// TestHostRW_Read_PermissionDenied verifies that hostRW.Read surfaces access denied errors.
func TestHostRW_Read_PermissionDenied(t *testing.T) {
if os.Getuid() == 0 {
t.Skip("skipping permission test: running as root")
}
tmpDir := t.TempDir()
protected := filepath.Join(tmpDir, "protected.txt")
err := os.WriteFile(protected, []byte("secret"), 0o000)
assert.NoError(t, err)
defer os.Chmod(protected, 0o644) // ensure cleanup
_, err = (&hostFs{}).ReadFile(protected)
assert.Error(t, err)
assert.Contains(t, err.Error(), "access denied")
}
// TestHostRW_Read_Directory verifies that hostRW.Read returns an error when given a directory path.
func TestHostRW_Read_Directory(t *testing.T) {
tmpDir := t.TempDir()
_, err := (&hostFs{}).ReadFile(tmpDir)
assert.Error(t, err, "expected error when reading a directory as a file")
}
// TestRootRW_Read_Directory verifies that rootRW.Read returns an error when given a directory.
func TestRootRW_Read_Directory(t *testing.T) {
workspace := t.TempDir()
root, err := os.OpenRoot(workspace)
assert.NoError(t, err)
defer root.Close()
// Create a subdirectory
err = root.Mkdir("subdir", 0o755)
assert.NoError(t, err)
_, err = (&sandboxFs{workspace: workspace}).ReadFile("subdir")
assert.Error(t, err, "expected error when reading a directory as a file")
}
// TestHostRW_Write_ParentDirMissing verifies that hostRW.Write creates parent dirs automatically.
func TestHostRW_Write_ParentDirMissing(t *testing.T) {
tmpDir := t.TempDir()
target := filepath.Join(tmpDir, "a", "b", "c", "file.txt")
err := (&hostFs{}).WriteFile(target, []byte("hello"))
assert.NoError(t, err)
data, err := os.ReadFile(target)
assert.NoError(t, err)
assert.Equal(t, "hello", string(data))
}
// TestRootRW_Write_ParentDirMissing verifies that rootRW.Write creates
// nested parent directories automatically within the sandbox.
func TestRootRW_Write_ParentDirMissing(t *testing.T) {
workspace := t.TempDir()
relPath := "x/y/z/file.txt"
err := (&sandboxFs{workspace: workspace}).WriteFile(relPath, []byte("nested"))
assert.NoError(t, err)
data, err := os.ReadFile(filepath.Join(workspace, relPath))
assert.NoError(t, err)
assert.Equal(t, "nested", string(data))
}
// TestHostRW_Write verifies the hostRW.Write helper function
func TestHostRW_Write(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "atomic_test.txt")
testData := []byte("atomic test content")
err := (&hostFs{}).WriteFile(testFile, testData)
assert.NoError(t, err)
content, err := os.ReadFile(testFile)
assert.NoError(t, err)
assert.Equal(t, testData, content)
// Verify it overwrites correctly
newData := []byte("new atomic content")
err = (&hostFs{}).WriteFile(testFile, newData)
assert.NoError(t, err)
content, err = os.ReadFile(testFile)
assert.NoError(t, err)
assert.Equal(t, newData, content)
}
// TestRootRW_Write verifies the rootRW.Write helper function
func TestRootRW_Write(t *testing.T) {
tmpDir := t.TempDir()
relPath := "atomic_root_test.txt"
testData := []byte("atomic root test content")
erw := &sandboxFs{workspace: tmpDir}
err := erw.WriteFile(relPath, testData)
assert.NoError(t, err)
root, err := os.OpenRoot(tmpDir)
assert.NoError(t, err)
defer root.Close()
f, err := root.Open(relPath)
assert.NoError(t, err)
defer f.Close()
content, err := io.ReadAll(f)
assert.NoError(t, err)
assert.Equal(t, testData, content)
// Verify it overwrites correctly
newData := []byte("new root atomic content")
err = erw.WriteFile(relPath, newData)
assert.NoError(t, err)
f2, err := root.Open(relPath)
assert.NoError(t, err)
defer f2.Close()
content, err = io.ReadAll(f2)
assert.NoError(t, err)
assert.Equal(t, newData, content)
}
+25 -15
View File
@@ -24,37 +24,37 @@ func (t *I2CTool) Description() string {
return "Interact with I2C bus devices for reading sensors and controlling peripherals. Actions: detect (list buses), scan (find devices on a bus), read (read bytes from device), write (send bytes to device). Linux only."
}
func (t *I2CTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *I2CTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"properties": map[string]any{
"action": map[string]any{
"type": "string",
"enum": []string{"detect", "scan", "read", "write"},
"description": "Action to perform: detect (list available I2C buses), scan (find devices on a bus), read (read bytes from a device), write (send bytes to a device)",
},
"bus": map[string]interface{}{
"bus": map[string]any{
"type": "string",
"description": "I2C bus number (e.g. \"1\" for /dev/i2c-1). Required for scan/read/write.",
},
"address": map[string]interface{}{
"address": map[string]any{
"type": "integer",
"description": "7-bit I2C device address (0x03-0x77). Required for read/write.",
},
"register": map[string]interface{}{
"register": map[string]any{
"type": "integer",
"description": "Register address to read from or write to. If set, sends register byte before read/write.",
},
"data": map[string]interface{}{
"data": map[string]any{
"type": "array",
"items": map[string]interface{}{"type": "integer"},
"items": map[string]any{"type": "integer"},
"description": "Bytes to write (0-255 each). Required for write action.",
},
"length": map[string]interface{}{
"length": map[string]any{
"type": "integer",
"description": "Number of bytes to read (1-256). Default: 1. Used with read action.",
},
"confirm": map[string]interface{}{
"confirm": map[string]any{
"type": "boolean",
"description": "Must be true for write operations. Safety guard to prevent accidental writes.",
},
@@ -63,7 +63,7 @@ func (t *I2CTool) Parameters() map[string]interface{} {
}
}
func (t *I2CTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if runtime.GOOS != "linux" {
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
}
@@ -95,7 +95,9 @@ func (t *I2CTool) detect() *ToolResult {
}
if len(matches) == 0 {
return SilentResult("No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)")
return SilentResult(
"No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)",
)
}
type busInfo struct {
@@ -115,14 +117,20 @@ func (t *I2CTool) detect() *ToolResult {
return SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result)))
}
// Helper functions for I2C operations (used by platform-specific implementations)
// isValidBusID checks that a bus identifier is a simple number (prevents path injection)
//
//nolint:unused // Used by i2c_linux.go
func isValidBusID(id string) bool {
matched, _ := regexp.MatchString(`^\d+$`, id)
return matched
}
// parseI2CAddress extracts and validates an I2C address from args
func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
//
//nolint:unused // Used by i2c_linux.go
func parseI2CAddress(args map[string]any) (int, *ToolResult) {
addrFloat, ok := args["address"].(float64)
if !ok {
return 0, ErrorResult("address is required (e.g. 0x38 for AHT20)")
@@ -135,7 +143,9 @@ func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
}
// parseI2CBus extracts and validates an I2C bus from args
func parseI2CBus(args map[string]interface{}) (string, *ToolResult) {
//
//nolint:unused // Used by i2c_linux.go
func parseI2CBus(args map[string]any) (string, *ToolResult) {
bus, ok := args["bus"].(string)
if !ok || bus == "" {
return "", ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)")
+13 -9
View File
@@ -74,7 +74,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
// scan probes valid 7-bit addresses on a bus for connected devices.
// Uses the same hybrid probe strategy as i2cdetect's MODE_AUTO:
// SMBus Quick Write for most addresses, SMBus Read Byte for EEPROM ranges.
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
func (t *I2CTool) scan(args map[string]any) *ToolResult {
bus, errResult := parseI2CBus(args)
if errResult != nil {
return errResult
@@ -99,7 +99,9 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
hasReadByte := funcs&i2cFuncSmbusReadByte != 0
if !hasQuick && !hasReadByte {
return ErrorResult(fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath))
return ErrorResult(
fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath),
)
}
type deviceEntry struct {
@@ -133,7 +135,7 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
}
result, _ := json.MarshalIndent(map[string]interface{}{
result, _ := json.MarshalIndent(map[string]any{
"bus": devPath,
"devices": found,
"count": len(found),
@@ -142,7 +144,7 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
}
// readDevice reads bytes from an I2C device, optionally at a specific register
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
func (t *I2CTool) readDevice(args map[string]any) *ToolResult {
bus, errResult := parseI2CBus(args)
if errResult != nil {
return errResult
@@ -180,7 +182,7 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
if reg < 0 || reg > 255 {
return ErrorResult("register must be between 0x00 and 0xFF")
}
_, err := syscall.Write(fd, []byte{byte(reg)})
_, err = syscall.Write(fd, []byte{byte(reg)})
if err != nil {
return ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err))
}
@@ -201,7 +203,7 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
intBytes[i] = int(buf[i])
}
result, _ := json.MarshalIndent(map[string]interface{}{
result, _ := json.MarshalIndent(map[string]any{
"bus": devPath,
"address": fmt.Sprintf("0x%02x", addr),
"bytes": intBytes,
@@ -212,10 +214,12 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
}
// writeDevice writes bytes to an I2C device, optionally at a specific register
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
func (t *I2CTool) writeDevice(args map[string]any) *ToolResult {
confirm, _ := args["confirm"].(bool)
if !confirm {
return ErrorResult("write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.")
return ErrorResult(
"write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.",
)
}
bus, errResult := parseI2CBus(args)
@@ -228,7 +232,7 @@ func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
return errResult
}
dataRaw, ok := args["data"].([]interface{})
dataRaw, ok := args["data"].([]any)
if !ok || len(dataRaw) == 0 {
return ErrorResult("data is required for write (array of byte values 0-255)")
}
+3 -3
View File
@@ -3,16 +3,16 @@
package tools
// scan is a stub for non-Linux platforms.
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
func (t *I2CTool) scan(args map[string]any) *ToolResult {
return ErrorResult("I2C is only supported on Linux")
}
// readDevice is a stub for non-Linux platforms.
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
func (t *I2CTool) readDevice(args map[string]any) *ToolResult {
return ErrorResult("I2C is only supported on Linux")
}
// writeDevice is a stub for non-Linux platforms.
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
func (t *I2CTool) writeDevice(args map[string]any) *ToolResult {
return ErrorResult("I2C is only supported on Linux")
}
+7 -7
View File
@@ -26,19 +26,19 @@ func (t *MessageTool) Description() string {
return "Send a message to user on a chat channel. Use this when you want to communicate something."
}
func (t *MessageTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *MessageTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"content": map[string]interface{}{
"properties": map[string]any{
"content": map[string]any{
"type": "string",
"description": "The message content to send",
},
"channel": map[string]interface{}{
"channel": map[string]any{
"type": "string",
"description": "Optional: target channel (telegram, whatsapp, etc.)",
},
"chat_id": map[string]interface{}{
"chat_id": map[string]any{
"type": "string",
"description": "Optional: target chat/user ID",
},
@@ -62,7 +62,7 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) {
t.sendCallback = callback
}
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
content, ok := args["content"].(string)
if !ok {
return &ToolResult{ForLLM: "content is required", IsError: true}
+10 -10
View File
@@ -19,7 +19,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
})
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"content": "Hello, world!",
}
@@ -70,7 +70,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
})
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
@@ -104,7 +104,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
})
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"content": "Test message",
}
@@ -136,7 +136,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
args := map[string]interface{}{} // content missing
args := map[string]any{} // content missing
result := tool.Execute(ctx, args)
@@ -158,7 +158,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
})
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"content": "Test message",
}
@@ -179,7 +179,7 @@ func TestMessageTool_Execute_NotConfigured(t *testing.T) {
// No SetSendCallback called
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"content": "Test message",
}
@@ -219,7 +219,7 @@ func TestMessageTool_Parameters(t *testing.T) {
t.Error("Expected type 'object'")
}
props, ok := params["properties"].(map[string]interface{})
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Expected properties to be a map")
}
@@ -231,7 +231,7 @@ func TestMessageTool_Parameters(t *testing.T) {
}
// Check content property
contentProp, ok := props["content"].(map[string]interface{})
contentProp, ok := props["content"].(map[string]any)
if !ok {
t.Error("Expected 'content' property")
}
@@ -240,7 +240,7 @@ func TestMessageTool_Parameters(t *testing.T) {
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]interface{})
channelProp, ok := props["channel"].(map[string]any)
if !ok {
t.Error("Expected 'channel' property")
}
@@ -249,7 +249,7 @@ func TestMessageTool_Parameters(t *testing.T) {
}
// Check chat_id property (optional)
chatIDProp, ok := props["chat_id"].(map[string]interface{})
chatIDProp, ok := props["chat_id"].(map[string]any)
if !ok {
t.Error("Expected 'chat_id' property")
}
+44 -23
View File
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"sort"
"sync"
"time"
@@ -34,16 +35,22 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
return tool, ok
}
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
}
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
func (r *ToolRegistry) ExecuteWithContext(
ctx context.Context,
name string,
args map[string]any,
channel, chatID string,
asyncCallback AsyncCallback,
) *ToolResult {
logger.InfoCF("tool", "Tool execution started",
map[string]interface{}{
map[string]any{
"tool": name,
"args": args,
})
@@ -51,7 +58,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
tool, ok := r.Get(name)
if !ok {
logger.ErrorCF("tool", "Tool not found",
map[string]interface{}{
map[string]any{
"tool": name,
})
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
@@ -66,7 +73,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
map[string]interface{}{
map[string]any{
"tool": name,
})
}
@@ -78,20 +85,20 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
// Log based on result type
if result.IsError {
logger.ErrorCF("tool", "Tool execution failed",
map[string]interface{}{
map[string]any{
"tool": name,
"duration": duration.Milliseconds(),
"error": result.ForLLM,
})
} else if result.Async {
logger.InfoCF("tool", "Tool started (async)",
map[string]interface{}{
map[string]any{
"tool": name,
"duration": duration.Milliseconds(),
})
} else {
logger.InfoCF("tool", "Tool execution completed",
map[string]interface{}{
map[string]any{
"tool": name,
"duration_ms": duration.Milliseconds(),
"result_length": len(result.ForLLM),
@@ -101,13 +108,27 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
return result
}
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
// sortedToolNames returns tool names in sorted order for deterministic iteration.
// This is critical for KV cache stability: non-deterministic map iteration would
// produce different system prompts and tool definitions on each call, invalidating
// the LLM's prefix cache even when no tools have changed.
func (r *ToolRegistry) sortedToolNames() []string {
names := make([]string, 0, len(r.tools))
for name := range r.tools {
names = append(names, name)
}
sort.Strings(names)
return names
}
func (r *ToolRegistry) GetDefinitions() []map[string]any {
r.mu.RLock()
defer r.mu.RUnlock()
definitions := make([]map[string]interface{}, 0, len(r.tools))
for _, tool := range r.tools {
definitions = append(definitions, ToolToSchema(tool))
sorted := r.sortedToolNames()
definitions := make([]map[string]any, 0, len(sorted))
for _, name := range sorted {
definitions = append(definitions, ToolToSchema(r.tools[name]))
}
return definitions
}
@@ -118,19 +139,21 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
for _, tool := range r.tools {
sorted := r.sortedToolNames()
definitions := make([]providers.ToolDefinition, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
schema := ToolToSchema(tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]interface{})
fn, ok := schema["function"].(map[string]any)
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
params, _ := fn["parameters"].(map[string]interface{})
params, _ := fn["parameters"].(map[string]any)
definitions = append(definitions, providers.ToolDefinition{
Type: "function",
@@ -149,11 +172,7 @@ func (r *ToolRegistry) List() []string {
r.mu.RLock()
defer r.mu.RUnlock()
names := make([]string, 0, len(r.tools))
for name := range r.tools {
names = append(names, name)
}
return names
return r.sortedToolNames()
}
// Count returns the number of registered tools.
@@ -169,8 +188,10 @@ func (r *ToolRegistry) GetSummaries() []string {
r.mu.RLock()
defer r.mu.RUnlock()
summaries := make([]string, 0, len(r.tools))
for _, tool := range r.tools {
sorted := r.sortedToolNames()
summaries := make([]string, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
}
return summaries
+350
View File
@@ -0,0 +1,350 @@
package tools
import (
"context"
"strings"
"sync"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// --- mock types ---
type mockRegistryTool struct {
name string
desc string
params map[string]any
result *ToolResult
}
func (m *mockRegistryTool) Name() string { return m.name }
func (m *mockRegistryTool) Description() string { return m.desc }
func (m *mockRegistryTool) Parameters() map[string]any { return m.params }
func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolResult {
return m.result
}
type mockCtxTool struct {
mockRegistryTool
channel string
chatID string
}
func (m *mockCtxTool) SetContext(channel, chatID string) {
m.channel = channel
m.chatID = chatID
}
type mockAsyncRegistryTool struct {
mockRegistryTool
cb AsyncCallback
}
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
m.cb = cb
}
// --- helpers ---
func newMockTool(name, desc string) *mockRegistryTool {
return &mockRegistryTool{
name: name,
desc: desc,
params: map[string]any{"type": "object"},
result: SilentResult("ok"),
}
}
// --- tests ---
func TestNewToolRegistry(t *testing.T) {
r := NewToolRegistry()
if r.Count() != 0 {
t.Errorf("expected empty registry, got count %d", r.Count())
}
if len(r.List()) != 0 {
t.Errorf("expected empty list, got %v", r.List())
}
}
func TestToolRegistry_RegisterAndGet(t *testing.T) {
r := NewToolRegistry()
tool := newMockTool("echo", "echoes input")
r.Register(tool)
got, ok := r.Get("echo")
if !ok {
t.Fatal("expected to find registered tool")
}
if got.Name() != "echo" {
t.Errorf("expected name 'echo', got %q", got.Name())
}
}
func TestToolRegistry_Get_NotFound(t *testing.T) {
r := NewToolRegistry()
_, ok := r.Get("nonexistent")
if ok {
t.Error("expected ok=false for unregistered tool")
}
}
func TestToolRegistry_RegisterOverwrite(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("dup", "first"))
r.Register(newMockTool("dup", "second"))
if r.Count() != 1 {
t.Errorf("expected count 1 after overwrite, got %d", r.Count())
}
tool, _ := r.Get("dup")
if tool.Description() != "second" {
t.Errorf("expected overwritten description 'second', got %q", tool.Description())
}
}
func TestToolRegistry_Execute_Success(t *testing.T) {
r := NewToolRegistry()
r.Register(&mockRegistryTool{
name: "greet",
desc: "says hello",
params: map[string]any{},
result: SilentResult("hello"),
})
result := r.Execute(context.Background(), "greet", nil)
if result.IsError {
t.Errorf("expected success, got error: %s", result.ForLLM)
}
if result.ForLLM != "hello" {
t.Errorf("expected ForLLM 'hello', got %q", result.ForLLM)
}
}
func TestToolRegistry_Execute_NotFound(t *testing.T) {
r := NewToolRegistry()
result := r.Execute(context.Background(), "missing", nil)
if !result.IsError {
t.Error("expected error for missing tool")
}
if !strings.Contains(result.ForLLM, "not found") {
t.Errorf("expected 'not found' in error, got %q", result.ForLLM)
}
if result.Err == nil {
t.Error("expected Err to be set via WithError")
}
}
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
if ct.channel != "telegram" {
t.Errorf("expected channel 'telegram', got %q", ct.channel)
}
if ct.chatID != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
}
}
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
if ct.channel != "" || ct.chatID != "" {
t.Error("SetContext should not be called with empty channel/chatID")
}
}
func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
r := NewToolRegistry()
at := &mockAsyncRegistryTool{
mockRegistryTool: *newMockTool("async_tool", "async work"),
}
at.result = AsyncResult("started")
r.Register(at)
called := false
cb := func(_ context.Context, _ *ToolResult) { called = true }
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
if at.cb == nil {
t.Error("expected SetCallback to have been called")
}
if !result.Async {
t.Error("expected async result")
}
at.cb(context.Background(), SilentResult("done"))
if !called {
t.Error("expected callback to be invoked")
}
}
func TestToolRegistry_GetDefinitions(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("alpha", "tool A"))
defs := r.GetDefinitions()
if len(defs) != 1 {
t.Fatalf("expected 1 definition, got %d", len(defs))
}
if defs[0]["type"] != "function" {
t.Errorf("expected type 'function', got %v", defs[0]["type"])
}
fn, ok := defs[0]["function"].(map[string]any)
if !ok {
t.Fatal("expected 'function' key to be a map")
}
if fn["name"] != "alpha" {
t.Errorf("expected name 'alpha', got %v", fn["name"])
}
if fn["description"] != "tool A" {
t.Errorf("expected description 'tool A', got %v", fn["description"])
}
}
func TestToolRegistry_ToProviderDefs(t *testing.T) {
r := NewToolRegistry()
params := map[string]any{"type": "object", "properties": map[string]any{}}
r.Register(&mockRegistryTool{
name: "beta",
desc: "tool B",
params: params,
result: SilentResult("ok"),
})
defs := r.ToProviderDefs()
if len(defs) != 1 {
t.Fatalf("expected 1 provider def, got %d", len(defs))
}
want := providers.ToolDefinition{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: "beta",
Description: "tool B",
Parameters: params,
},
}
got := defs[0]
if got.Type != want.Type {
t.Errorf("Type: want %q, got %q", want.Type, got.Type)
}
if got.Function.Name != want.Function.Name {
t.Errorf("Name: want %q, got %q", want.Function.Name, got.Function.Name)
}
if got.Function.Description != want.Function.Description {
t.Errorf("Description: want %q, got %q", want.Function.Description, got.Function.Description)
}
}
func TestToolRegistry_List(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("x", ""))
r.Register(newMockTool("y", ""))
names := r.List()
if len(names) != 2 {
t.Fatalf("expected 2 names, got %d", len(names))
}
nameSet := map[string]bool{}
for _, n := range names {
nameSet[n] = true
}
if !nameSet["x"] || !nameSet["y"] {
t.Errorf("expected names {x, y}, got %v", names)
}
}
func TestToolRegistry_Count(t *testing.T) {
r := NewToolRegistry()
if r.Count() != 0 {
t.Errorf("expected 0, got %d", r.Count())
}
r.Register(newMockTool("a", ""))
r.Register(newMockTool("b", ""))
if r.Count() != 2 {
t.Errorf("expected 2, got %d", r.Count())
}
r.Register(newMockTool("a", "replaced"))
if r.Count() != 2 {
t.Errorf("expected 2 after overwrite, got %d", r.Count())
}
}
func TestToolRegistry_GetSummaries(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("read_file", "Reads a file"))
summaries := r.GetSummaries()
if len(summaries) != 1 {
t.Fatalf("expected 1 summary, got %d", len(summaries))
}
if !strings.Contains(summaries[0], "`read_file`") {
t.Errorf("expected backtick-quoted name in summary, got %q", summaries[0])
}
if !strings.Contains(summaries[0], "Reads a file") {
t.Errorf("expected description in summary, got %q", summaries[0])
}
}
func TestToolToSchema(t *testing.T) {
tool := newMockTool("demo", "demo tool")
schema := ToolToSchema(tool)
if schema["type"] != "function" {
t.Errorf("expected type 'function', got %v", schema["type"])
}
fn, ok := schema["function"].(map[string]any)
if !ok {
t.Fatal("expected 'function' to be a map")
}
if fn["name"] != "demo" {
t.Errorf("expected name 'demo', got %v", fn["name"])
}
if fn["description"] != "demo tool" {
t.Errorf("expected description 'demo tool', got %v", fn["description"])
}
if fn["parameters"] == nil {
t.Error("expected parameters to be set")
}
}
func TestToolRegistry_ConcurrentAccess(t *testing.T) {
r := NewToolRegistry()
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
name := string(rune('A' + n%26))
r.Register(newMockTool(name, "concurrent"))
r.Get(name)
r.Count()
r.List()
r.GetDefinitions()
}(i)
}
wg.Wait()
if r.Count() == 0 {
t.Error("expected tools to be registered after concurrent access")
}
}
+1 -1
View File
@@ -192,7 +192,7 @@ func TestToolResultJSONStructure(t *testing.T) {
}
// Verify JSON structure
var parsed map[string]interface{}
var parsed map[string]any
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
+45 -13
View File
@@ -3,6 +3,7 @@ package tools
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"os/exec"
@@ -75,11 +76,11 @@ func NewExecTool(workingDir string, restrict bool) *ExecTool {
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
denyPatterns := make([]*regexp.Regexp, 0)
enableDenyPatterns := true
if config != nil {
execConfig := config.Tools.Exec
enableDenyPatterns = execConfig.EnableDenyPatterns
enableDenyPatterns := execConfig.EnableDenyPatterns
if enableDenyPatterns {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
if len(execConfig.CustomDenyPatterns) > 0 {
fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns)
for _, pattern := range execConfig.CustomDenyPatterns {
@@ -90,8 +91,6 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
}
denyPatterns = append(denyPatterns, re)
}
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
} else {
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
@@ -118,15 +117,15 @@ func (t *ExecTool) Description() string {
return "Execute a shell command and return its output. Use with caution."
}
func (t *ExecTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *ExecTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"command": map[string]interface{}{
"properties": map[string]any{
"command": map[string]any{
"type": "string",
"description": "The shell command to execute",
},
"working_dir": map[string]interface{}{
"working_dir": map[string]any{
"type": "string",
"description": "Optional working directory for the command",
},
@@ -135,7 +134,7 @@ func (t *ExecTool) Parameters() map[string]interface{} {
}
}
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
command, ok := args["command"].(string)
if !ok {
return ErrorResult("command is required")
@@ -143,7 +142,15 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
cwd := t.workingDir
if wd, ok := args["working_dir"].(string); ok && wd != "" {
cwd = wd
if t.restrictToWorkspace && t.workingDir != "" {
resolvedWD, err := validatePath(wd, t.workingDir, true)
if err != nil {
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
}
cwd = resolvedWD
} else {
cwd = wd
}
}
if cwd == "" {
@@ -177,18 +184,43 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
cmd.Dir = cwd
}
prepareCommandForTermination(cmd)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err := cmd.Start(); err != nil {
return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
}
done := make(chan error, 1)
go func() {
done <- cmd.Wait()
}()
var err error
select {
case err = <-done:
case <-cmdCtx.Done():
_ = terminateProcessTree(cmd)
select {
case err = <-done:
case <-time.After(2 * time.Second):
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
err = <-done
}
}
output := stdout.String()
if stderr.Len() > 0 {
output += "\nSTDERR:\n" + stderr.String()
}
if err != nil {
if cmdCtx.Err() == context.DeadlineExceeded {
if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
+32
View File
@@ -0,0 +1,32 @@
//go:build !windows
package tools
import (
"os/exec"
"syscall"
)
func prepareCommandForTermination(cmd *exec.Cmd) {
if cmd == nil {
return
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
func terminateProcessTree(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil {
return nil
}
pid := cmd.Process.Pid
if pid <= 0 {
return nil
}
// Kill the entire process group spawned by the shell command.
_ = syscall.Kill(-pid, syscall.SIGKILL)
// Fallback kill on the shell process itself.
_ = cmd.Process.Kill()
return nil
}
+27
View File
@@ -0,0 +1,27 @@
//go:build windows
package tools
import (
"os/exec"
"strconv"
)
func prepareCommandForTermination(cmd *exec.Cmd) {
// no-op on Windows
}
func terminateProcessTree(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil {
return nil
}
pid := cmd.Process.Pid
if pid <= 0 {
return nil
}
_ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run()
_ = cmd.Process.Kill()
return nil
}
+75 -11
View File
@@ -14,7 +14,7 @@ func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"command": "echo 'hello world'",
}
@@ -41,7 +41,7 @@ func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"command": "ls /nonexistent_directory_12345",
}
@@ -69,7 +69,7 @@ func TestShellTool_Timeout(t *testing.T) {
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"command": "sleep 10",
}
@@ -91,12 +91,12 @@ func TestShellTool_WorkingDir(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
os.WriteFile(testFile, []byte("test content"), 0o644)
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"command": "cat test.txt",
"working_dir": tmpDir,
}
@@ -117,7 +117,7 @@ func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"command": "rm -rf /",
}
@@ -138,7 +138,7 @@ func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{}
args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -153,7 +153,7 @@ func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"command": "sh -c 'echo stdout; echo stderr >&2'",
}
@@ -174,7 +174,7 @@ func TestShellTool_OutputTruncation(t *testing.T) {
ctx := context.Background()
// Generate long output (>10000 chars)
args := map[string]interface{}{
args := map[string]any{
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
}
@@ -186,6 +186,66 @@ func TestShellTool_OutputTruncation(t *testing.T) {
}
}
// TestShellTool_WorkingDir_OutsideWorkspace verifies that working_dir cannot escape the workspace directly
func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
root := t.TempDir()
workspace := filepath.Join(root, "workspace")
outsideDir := filepath.Join(root, "outside")
if err := os.MkdirAll(workspace, 0o755); err != nil {
t.Fatalf("failed to create workspace: %v", err)
}
if err := os.MkdirAll(outsideDir, 0o755); err != nil {
t.Fatalf("failed to create outside dir: %v", err)
}
tool := NewExecTool(workspace, true)
result := tool.Execute(context.Background(), map[string]any{
"command": "pwd",
"working_dir": outsideDir,
})
if !result.IsError {
t.Fatalf("expected working_dir outside workspace to be blocked, got output: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "blocked") {
t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
}
}
// TestShellTool_WorkingDir_SymlinkEscape verifies that a symlink inside the workspace
// pointing outside cannot be used as working_dir to escape the sandbox.
func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
root := t.TempDir()
workspace := filepath.Join(root, "workspace")
secretDir := filepath.Join(root, "secret")
if err := os.MkdirAll(workspace, 0o755); err != nil {
t.Fatalf("failed to create workspace: %v", err)
}
if err := os.MkdirAll(secretDir, 0o755); err != nil {
t.Fatalf("failed to create secret dir: %v", err)
}
os.WriteFile(filepath.Join(secretDir, "secret.txt"), []byte("top secret"), 0o644)
// symlink lives inside the workspace but resolves to secretDir outside it
link := filepath.Join(workspace, "escape")
if err := os.Symlink(secretDir, link); err != nil {
t.Skipf("symlinks not supported in this environment: %v", err)
}
tool := NewExecTool(workspace, true)
result := tool.Execute(context.Background(), map[string]any{
"command": "cat secret.txt",
"working_dir": link,
})
if !result.IsError {
t.Fatalf("expected symlink working_dir escape to be blocked, got output: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "blocked") {
t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
}
}
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
@@ -193,7 +253,7 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"command": "cat ../../etc/passwd",
}
@@ -205,6 +265,10 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
t.Errorf(
"Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s",
result.ForLLM,
result.ForUser,
)
}
}
+61
View File
@@ -0,0 +1,61 @@
//go:build !windows
package tools
import (
"context"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
)
func processExists(pid int) bool {
if pid <= 0 {
return false
}
err := syscall.Kill(pid, 0)
return err == nil || err == syscall.EPERM
}
func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
tool := NewExecTool(t.TempDir(), false)
tool.SetTimeout(500 * time.Millisecond)
args := map[string]any{
// Spawn a child process that would outlive the shell unless process-group kill is used.
"command": "sleep 60 & echo $! > child.pid; wait",
}
result := tool.Execute(context.Background(), args)
if !result.IsError {
t.Fatalf("expected timeout error, got success: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "timed out") {
t.Fatalf("expected timeout message, got: %s", result.ForLLM)
}
childPIDPath := filepath.Join(tool.workingDir, "child.pid")
data, err := os.ReadFile(childPIDPath)
if err != nil {
t.Fatalf("failed to read child pid file: %v", err)
}
childPID, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil {
t.Fatalf("failed to parse child pid: %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if !processExists(childPID) {
return
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("child process %d is still running after timeout", childPID)
}
+201
View File
@@ -0,0 +1,201 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/utils"
)
// InstallSkillTool allows the LLM agent to install skills from registries.
// It shares the same RegistryManager that FindSkillsTool uses,
// so all registries configured in config are available for installation.
type InstallSkillTool struct {
registryMgr *skills.RegistryManager
workspace string
mu sync.Mutex
}
// NewInstallSkillTool creates a new InstallSkillTool.
// registryMgr is the shared registry manager (same instance as FindSkillsTool).
// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/.
func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool {
return &InstallSkillTool{
registryMgr: registryMgr,
workspace: workspace,
mu: sync.Mutex{},
}
}
func (t *InstallSkillTool) Name() string {
return "install_skill"
}
func (t *InstallSkillTool) Description() string {
return "Install a skill from a registry by slug. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills."
}
func (t *InstallSkillTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"slug": map[string]any{
"type": "string",
"description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')",
},
"version": map[string]any{
"type": "string",
"description": "Specific version to install (optional, defaults to latest)",
},
"registry": map[string]any{
"type": "string",
"description": "Registry to install from (required, e.g., 'clawhub')",
},
"force": map[string]any{
"type": "boolean",
"description": "Force reinstall if skill already exists (default false)",
},
},
"required": []string{"slug", "registry"},
}
}
func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
// Install lock to prevent concurrent directory operations.
// Ideally this should be done at a `slug` level, currently, its at a `workspace` level.
t.mu.Lock()
defer t.mu.Unlock()
// Validate slug
slug, _ := args["slug"].(string)
if err := utils.ValidateSkillIdentifier(slug); err != nil {
return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error()))
}
// Validate registry
registryName, _ := args["registry"].(string)
if err := utils.ValidateSkillIdentifier(registryName); err != nil {
return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error()))
}
version, _ := args["version"].(string)
force, _ := args["force"].(bool)
// Check if already installed.
skillsDir := filepath.Join(t.workspace, "skills")
targetDir := filepath.Join(skillsDir, slug)
if !force {
if _, err := os.Stat(targetDir); err == nil {
return ErrorResult(
fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
)
}
} else {
// Force: remove existing if present.
os.RemoveAll(targetDir)
}
// Resolve which registry to use.
registry := t.registryMgr.GetRegistry(registryName)
if registry == nil {
return ErrorResult(fmt.Sprintf("registry %q not found", registryName))
}
// Ensure skills directory exists.
if err := os.MkdirAll(skillsDir, 0o755); err != nil {
return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err))
}
// Download and install (handles metadata, version resolution, extraction).
result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir)
if err != nil {
// Clean up partial install.
rmErr := os.RemoveAll(targetDir)
if rmErr != nil {
logger.ErrorCF("tool", "Failed to remove partial install",
map[string]any{
"tool": "install_skill",
"target_dir": targetDir,
"error": rmErr.Error(),
})
}
return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err))
}
// Moderation: block malware.
if result.IsMalwareBlocked {
rmErr := os.RemoveAll(targetDir)
if rmErr != nil {
logger.ErrorCF("tool", "Failed to remove partial install",
map[string]any{
"tool": "install_skill",
"target_dir": targetDir,
"error": rmErr.Error(),
})
}
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
}
// Write origin metadata.
if err := writeOriginMeta(targetDir, registry.Name(), slug, result.Version); err != nil {
logger.ErrorCF("tool", "Failed to write origin metadata",
map[string]any{
"tool": "install_skill",
"error": err.Error(),
"target": targetDir,
"registry": registry.Name(),
"slug": slug,
"version": result.Version,
})
_ = err
}
// Build result with moderation warning if suspicious.
var output string
if result.IsSuspicious {
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
}
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
slug, result.Version, registry.Name(), targetDir)
if result.Summary != "" {
output += fmt.Sprintf("Description: %s\n", result.Summary)
}
output += "\nThe skill is now available and can be loaded in the current session."
return SilentResult(output)
}
// originMeta tracks which registry a skill was installed from.
type originMeta struct {
Version int `json:"version"`
Registry string `json:"registry"`
Slug string `json:"slug"`
InstalledVersion string `json:"installed_version"`
InstalledAt int64 `json:"installed_at"`
}
func writeOriginMeta(targetDir, registryName, slug, version string) error {
meta := originMeta{
Version: 1,
Registry: registryName,
Slug: slug,
InstalledVersion: version,
InstalledAt: time.Now().UnixMilli(),
}
data, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return err
}
return os.WriteFile(filepath.Join(targetDir, ".skill-origin.json"), data, 0o644)
}
+104
View File
@@ -0,0 +1,104 @@
package tools
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/skills"
)
func TestInstallSkillToolName(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
assert.Equal(t, "install_skill", tool.Name())
}
func TestInstallSkillToolMissingSlug(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
result := tool.Execute(context.Background(), map[string]any{})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
}
func TestInstallSkillToolEmptySlug(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
result := tool.Execute(context.Background(), map[string]any{
"slug": " ",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
}
func TestInstallSkillToolUnsafeSlug(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
cases := []string{
"../etc/passwd",
"path/traversal",
"path\\traversal",
}
for _, slug := range cases {
result := tool.Execute(context.Background(), map[string]any{
"slug": slug,
})
assert.True(t, result.IsError, "slug %q should be rejected", slug)
assert.Contains(t, result.ForLLM, "invalid slug")
}
}
func TestInstallSkillToolAlreadyExists(t *testing.T) {
workspace := t.TempDir()
skillDir := filepath.Join(workspace, "skills", "existing-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
result := tool.Execute(context.Background(), map[string]any{
"slug": "existing-skill",
"registry": "clawhub",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "already installed")
}
func TestInstallSkillToolRegistryNotFound(t *testing.T) {
workspace := t.TempDir()
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
result := tool.Execute(context.Background(), map[string]any{
"slug": "some-skill",
"registry": "nonexistent",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "registry")
assert.Contains(t, result.ForLLM, "not found")
}
func TestInstallSkillToolParameters(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
params := tool.Parameters()
props, ok := params["properties"].(map[string]any)
assert.True(t, ok)
assert.Contains(t, props, "slug")
assert.Contains(t, props, "version")
assert.Contains(t, props, "registry")
assert.Contains(t, props, "force")
required, ok := params["required"].([]string)
assert.True(t, ok)
assert.Contains(t, required, "slug")
assert.Contains(t, required, "registry")
}
func TestInstallSkillToolMissingRegistry(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
result := tool.Execute(context.Background(), map[string]any{
"slug": "some-skill",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "invalid registry")
}
+119
View File
@@ -0,0 +1,119 @@
package tools
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/skills"
)
// FindSkillsTool allows the LLM agent to search for installable skills from registries.
type FindSkillsTool struct {
registryMgr *skills.RegistryManager
cache *skills.SearchCache
}
// NewFindSkillsTool creates a new FindSkillsTool.
// registryMgr is the shared registry manager (built from config in createToolRegistry).
// cache is the search cache for deduplicating similar queries.
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
return &FindSkillsTool{
registryMgr: registryMgr,
cache: cache,
}
}
func (t *FindSkillsTool) Name() string {
return "find_skills"
}
func (t *FindSkillsTool) Description() string {
return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill."
}
func (t *FindSkillsTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')",
},
"limit": map[string]any{
"type": "integer",
"description": "Maximum number of results to return (1-20, default 5)",
"minimum": 1.0,
"maximum": 20.0,
},
},
"required": []string{"query"},
}
}
func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
query, ok := args["query"].(string)
query = strings.ToLower(strings.TrimSpace(query))
if !ok || query == "" {
return ErrorResult("query is required and must be a non-empty string")
}
limit := 5
if l, ok := args["limit"].(float64); ok {
li := int(l)
if li >= 1 && li <= 20 {
limit = li
}
}
// Check cache first.
if t.cache != nil {
if cached, hit := t.cache.Get(query); hit {
return SilentResult(formatSearchResults(query, cached, true))
}
}
// Search all registries.
results, err := t.registryMgr.SearchAll(ctx, query, limit)
if err != nil {
return ErrorResult(fmt.Sprintf("skill search failed: %v", err))
}
// Cache the results.
if t.cache != nil && len(results) > 0 {
t.cache.Put(query, results)
}
return SilentResult(formatSearchResults(query, results, false))
}
func formatSearchResults(query string, results []skills.SearchResult, cached bool) string {
if len(results) == 0 {
return fmt.Sprintf("No skills found for query: %q", query)
}
var sb strings.Builder
source := ""
if cached {
source = " (cached)"
}
sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source))
for i, r := range results {
sb.WriteString(fmt.Sprintf("%d. **%s**", i+1, r.Slug))
if r.Version != "" {
sb.WriteString(fmt.Sprintf(" v%s", r.Version))
}
sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName))
if r.DisplayName != "" && r.DisplayName != r.Slug {
sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName))
}
if r.Summary != "" {
sb.WriteString(fmt.Sprintf(" %s\n", r.Summary))
}
sb.WriteString("\n")
}
sb.WriteString("Use install_skill with the slug to install a skill.")
return sb.String()
}
+90
View File
@@ -0,0 +1,90 @@
package tools
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/sipeed/picoclaw/pkg/skills"
)
func TestFindSkillsToolName(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
assert.Equal(t, "find_skills", tool.Name())
}
func TestFindSkillsToolMissingQuery(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
result := tool.Execute(context.Background(), map[string]any{})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "query is required")
}
func TestFindSkillsToolEmptyQuery(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
result := tool.Execute(context.Background(), map[string]any{
"query": " ",
})
assert.True(t, result.IsError)
}
func TestFindSkillsToolCacheHit(t *testing.T) {
cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min
cache.Put("github", []skills.SearchResult{
{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
})
tool := NewFindSkillsTool(skills.NewRegistryManager(), cache)
result := tool.Execute(context.Background(), map[string]any{
"query": "github",
})
assert.False(t, result.IsError)
assert.Contains(t, result.ForLLM, "github")
assert.Contains(t, result.ForLLM, "cached")
}
func TestFindSkillsToolParameters(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
params := tool.Parameters()
props, ok := params["properties"].(map[string]any)
assert.True(t, ok)
assert.Contains(t, props, "query")
assert.Contains(t, props, "limit")
required, ok := params["required"].([]string)
assert.True(t, ok)
assert.Contains(t, required, "query")
}
func TestFindSkillsToolDescription(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
assert.NotEmpty(t, tool.Description())
assert.Contains(t, tool.Description(), "skill")
}
func TestFormatSearchResultsEmpty(t *testing.T) {
result := formatSearchResults("test query", nil, false)
assert.Contains(t, result, "No skills found")
}
func TestFormatSearchResultsWithData(t *testing.T) {
results := []skills.SearchResult{
{
Slug: "github",
Score: 0.95,
DisplayName: "GitHub",
Summary: "GitHub API integration",
Version: "1.0.0",
RegistryName: "clawhub",
},
}
output := formatSearchResults("github", results, false)
assert.Contains(t, output, "github")
assert.Contains(t, output, "v1.0.0")
assert.Contains(t, output, "0.950")
assert.Contains(t, output, "clawhub")
assert.Contains(t, output, "install_skill")
}
+10 -9
View File
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"strings"
)
type SpawnTool struct {
@@ -34,19 +35,19 @@ func (t *SpawnTool) Description() string {
return "Spawn a subagent to handle a task in the background. Use this for complex or time-consuming tasks that can run independently. The subagent will complete the task and report back when done."
}
func (t *SpawnTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *SpawnTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"task": map[string]interface{}{
"properties": map[string]any{
"task": map[string]any{
"type": "string",
"description": "The task for subagent to complete",
},
"label": map[string]interface{}{
"label": map[string]any{
"type": "string",
"description": "Optional short label for the task (for display)",
},
"agent_id": map[string]interface{}{
"agent_id": map[string]any{
"type": "string",
"description": "Optional target agent ID to delegate the task to",
},
@@ -64,10 +65,10 @@ func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
t.allowlistCheck = check
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required")
if !ok || strings.TrimSpace(task) == "" {
return ErrorResult("task is required and must be a non-empty string")
}
label, _ := args["label"].(string)
+79
View File
@@ -0,0 +1,79 @@
package tools
import (
"context"
"strings"
"testing"
)
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSpawnTool(manager)
ctx := context.Background()
tests := []struct {
name string
args map[string]any
}{
{"empty string", map[string]any{"task": ""}},
{"whitespace only", map[string]any{"task": " "}},
{"tabs and newlines", map[string]any{"task": "\t\n "}},
{"missing task key", map[string]any{"label": "test"}},
{"wrong type", map[string]any{"task": 123}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tool.Execute(ctx, tt.args)
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected error for invalid task parameter")
}
if !strings.Contains(result.ForLLM, "task is required") {
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
}
})
}
}
func TestSpawnTool_Execute_ValidTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSpawnTool(manager)
ctx := context.Background()
args := map[string]any{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
result := tool.Execute(ctx, args)
if result == nil {
t.Fatal("Result should not be nil")
}
if result.IsError {
t.Errorf("Expected success for valid task, got error: %s", result.ForLLM)
}
if !result.Async {
t.Error("SpawnTool should return async result")
}
}
func TestSpawnTool_Execute_NilManager(t *testing.T) {
tool := NewSpawnTool(nil)
ctx := context.Background()
args := map[string]any{"task": "test task"}
result := tool.Execute(ctx, args)
if !result.IsError {
t.Error("Expected error for nil manager")
}
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
}
}
+21 -15
View File
@@ -24,41 +24,41 @@ func (t *SPITool) Description() string {
return "Interact with SPI bus devices for high-speed peripheral communication. Actions: list (find SPI devices), transfer (full-duplex send/receive), read (receive bytes). Linux only."
}
func (t *SPITool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *SPITool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"properties": map[string]any{
"action": map[string]any{
"type": "string",
"enum": []string{"list", "transfer", "read"},
"description": "Action to perform: list (find available SPI devices), transfer (full-duplex send/receive), read (receive bytes by sending zeros)",
},
"device": map[string]interface{}{
"device": map[string]any{
"type": "string",
"description": "SPI device identifier (e.g. \"2.0\" for /dev/spidev2.0). Required for transfer/read.",
},
"speed": map[string]interface{}{
"speed": map[string]any{
"type": "integer",
"description": "SPI clock speed in Hz. Default: 1000000 (1 MHz).",
},
"mode": map[string]interface{}{
"mode": map[string]any{
"type": "integer",
"description": "SPI mode (0-3). Default: 0. Mode sets CPOL and CPHA: 0=0,0 1=0,1 2=1,0 3=1,1.",
},
"bits": map[string]interface{}{
"bits": map[string]any{
"type": "integer",
"description": "Bits per word. Default: 8.",
},
"data": map[string]interface{}{
"data": map[string]any{
"type": "array",
"items": map[string]interface{}{"type": "integer"},
"items": map[string]any{"type": "integer"},
"description": "Bytes to send (0-255 each). Required for transfer action.",
},
"length": map[string]interface{}{
"length": map[string]any{
"type": "integer",
"description": "Number of bytes to read (1-4096). Required for read action.",
},
"confirm": map[string]interface{}{
"confirm": map[string]any{
"type": "boolean",
"description": "Must be true for transfer operations. Safety guard to prevent accidental writes.",
},
@@ -67,7 +67,7 @@ func (t *SPITool) Parameters() map[string]interface{} {
}
}
func (t *SPITool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if runtime.GOOS != "linux" {
return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.")
}
@@ -97,7 +97,9 @@ func (t *SPITool) list() *ToolResult {
}
if len(matches) == 0 {
return SilentResult("No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded")
return SilentResult(
"No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded",
)
}
type devInfo struct {
@@ -117,8 +119,12 @@ func (t *SPITool) list() *ToolResult {
return SilentResult(fmt.Sprintf("Found %d SPI device(s):\n%s", len(devices), string(result)))
}
// Helper function for SPI operations (used by platform-specific implementations)
// parseSPIArgs extracts and validates common SPI parameters
func parseSPIArgs(args map[string]interface{}) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
//
//nolint:unused // Used by spi_linux.go
func parseSPIArgs(args map[string]any) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
dev, ok := args["device"].(string)
if !ok || dev == "" {
return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)"
+8 -6
View File
@@ -66,10 +66,12 @@ func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *T
}
// transfer performs a full-duplex SPI transfer
func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
func (t *SPITool) transfer(args map[string]any) *ToolResult {
confirm, _ := args["confirm"].(bool)
if !confirm {
return ErrorResult("transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.")
return ErrorResult(
"transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.",
)
}
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
@@ -77,7 +79,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
return ErrorResult(errMsg)
}
dataRaw, ok := args["data"].([]interface{})
dataRaw, ok := args["data"].([]any)
if !ok || len(dataRaw) == 0 {
return ErrorResult("data is required for transfer (array of byte values 0-255)")
}
@@ -130,7 +132,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
intBytes[i] = int(b)
}
result, _ := json.MarshalIndent(map[string]interface{}{
result, _ := json.MarshalIndent(map[string]any{
"device": devPath,
"sent": len(txBuf),
"received": intBytes,
@@ -140,7 +142,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
}
// readDevice reads bytes from SPI by sending zeros (read-only, no confirm needed)
func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
func (t *SPITool) readDevice(args map[string]any) *ToolResult {
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
if errMsg != "" {
return ErrorResult(errMsg)
@@ -186,7 +188,7 @@ func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
intBytes[i] = int(b)
}
result, _ := json.MarshalIndent(map[string]interface{}{
result, _ := json.MarshalIndent(map[string]any{
"device": devPath,
"bytes": intBytes,
"hex": hexBytes,
+2 -2
View File
@@ -3,11 +3,11 @@
package tools
// transfer is a stub for non-Linux platforms.
func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
func (t *SPITool) transfer(args map[string]any) *ToolResult {
return ErrorResult("SPI is only supported on Linux")
}
// readDevice is a stub for non-Linux platforms.
func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
func (t *SPITool) readDevice(args map[string]any) *ToolResult {
return ErrorResult("SPI is only supported on Linux")
}
+28 -15
View File
@@ -38,7 +38,11 @@ type SubagentManager struct {
nextID int
}
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
func NewSubagentManager(
provider providers.LLMProvider,
defaultModel, workspace string,
bus *bus.MessageBus,
) *SubagentManager {
return &SubagentManager{
tasks: make(map[string]*SubagentTask),
provider: provider,
@@ -76,7 +80,11 @@ func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) {
func (sm *SubagentManager) Spawn(
ctx context.Context,
task, label, agentID, originChannel, originChatID string,
callback AsyncCallback,
) (string, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -124,12 +132,12 @@ After completing the task, provide a clear summary of what was done.`
},
}
// Check if context is already cancelled before starting
// Check if context is already canceled before starting
select {
case <-ctx.Done():
sm.mu.Lock()
task.Status = "cancelled"
task.Result = "Task cancelled before execution"
task.Status = "canceled"
task.Result = "Task canceled before execution"
sm.mu.Unlock()
return
default:
@@ -177,10 +185,10 @@ After completing the task, provide a clear summary of what was done.`
if err != nil {
task.Status = "failed"
task.Result = fmt.Sprintf("Error: %v", err)
// Check if it was cancelled
// Check if it was canceled
if ctx.Err() != nil {
task.Status = "cancelled"
task.Result = "Task cancelled during execution"
task.Status = "canceled"
task.Result = "Task canceled during execution"
}
result = &ToolResult{
ForLLM: task.Result,
@@ -194,7 +202,12 @@ After completing the task, provide a clear summary of what was done.`
task.Status = "completed"
task.Result = loopResult.Content
result = &ToolResult{
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
ForLLM: fmt.Sprintf(
"Subagent '%s' completed (iterations: %d): %s",
task.Label,
loopResult.Iterations,
loopResult.Content,
),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
@@ -258,15 +271,15 @@ func (t *SubagentTool) Description() string {
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
}
func (t *SubagentTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *SubagentTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"task": map[string]interface{}{
"properties": map[string]any{
"task": map[string]any{
"type": "string",
"description": "The task for subagent to complete",
},
"label": map[string]interface{}{
"label": map[string]any{
"type": "string",
"description": "Optional short label for the task (for display)",
},
@@ -280,7 +293,7 @@ func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
+18 -12
View File
@@ -11,10 +11,16 @@ import (
// MockLLMProvider is a test implementation of LLMProvider
type MockLLMProvider struct {
lastOptions map[string]interface{}
lastOptions map[string]any
}
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
func (m *MockLLMProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
options map[string]any,
) (*providers.LLMResponse, error) {
m.lastOptions = options
// Find the last user message to generate a response
for i := len(messages) - 1; i >= 0; i-- {
@@ -47,7 +53,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
tool.SetContext("cli", "direct")
ctx := context.Background()
args := map[string]interface{}{"task": "Do something"}
args := map[string]any{"task": "Do something"}
result := tool.Execute(ctx, args)
if result == nil || result.IsError {
@@ -108,13 +114,13 @@ func TestSubagentTool_Parameters(t *testing.T) {
}
// Check properties
props, ok := params["properties"].(map[string]interface{})
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Properties should be a map")
}
// Verify task parameter
task, ok := props["task"].(map[string]interface{})
task, ok := props["task"].(map[string]any)
if !ok {
t.Fatal("Task parameter should exist")
}
@@ -123,7 +129,7 @@ func TestSubagentTool_Parameters(t *testing.T) {
}
// Verify label parameter
label, ok := props["label"].(map[string]interface{})
label, ok := props["label"].(map[string]any)
if !ok {
t.Fatal("Label parameter should exist")
}
@@ -163,7 +169,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
@@ -218,7 +224,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"task": "Test task without label",
}
@@ -241,7 +247,7 @@ func TestSubagentTool_Execute_MissingTask(t *testing.T) {
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"label": "test",
}
@@ -268,7 +274,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) {
tool := NewSubagentTool(nil)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"task": "test task",
}
@@ -297,7 +303,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
tool.SetContext(channel, chatID)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"task": "Test context passing",
}
@@ -324,7 +330,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a task that will generate long response
longTask := strings.Repeat("This is a very long task description. ", 100)
args := map[string]interface{}{
args := map[string]any{
"task": longTask,
"label": "long-test",
}
+6 -1
View File
@@ -33,7 +33,12 @@ type ToolLoopResult struct {
// RunToolLoop executes the LLM + tool call iteration loop.
// This is the core agent logic that can be reused by both main agent and subagents.
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
func RunToolLoop(
ctx context.Context,
config ToolLoopConfig,
messages []providers.Message,
channel, chatID string,
) (*ToolLoopResult, error) {
iteration := 0
var finalContent string
+15 -9
View File
@@ -10,11 +10,11 @@ type Message struct {
}
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
ID string `json:"id"`
Type string `json:"type"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]any `json:"arguments,omitempty"`
}
type FunctionCall struct {
@@ -36,7 +36,13 @@ type UsageInfo struct {
}
type LLMProvider interface {
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error)
GetDefaultModel() string
}
@@ -46,7 +52,7 @@ type ToolDefinition struct {
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]any `json:"parameters"`
}
+211 -40
View File
@@ -1,6 +1,7 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -16,12 +17,50 @@ const (
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
)
// createHTTPClient creates an HTTP client with optional proxy support
func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
DisableCompression: false,
TLSHandshakeTimeout: 15 * time.Second,
},
}
if proxyURL != "" {
proxy, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
scheme := strings.ToLower(proxy.Scheme)
switch scheme {
case "http", "https", "socks5", "socks5h":
default:
return nil, fmt.Errorf(
"unsupported proxy scheme %q (supported: http, https, socks5, socks5h)",
proxy.Scheme,
)
}
if proxy.Host == "" {
return nil, fmt.Errorf("invalid proxy URL: missing host")
}
client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy)
} else {
client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment
}
return client, nil
}
type SearchProvider interface {
Search(ctx context.Context, query string, count int) (string, error)
}
type BraveSearchProvider struct {
apiKey string
proxy string
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -36,7 +75,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
client := &http.Client{Timeout: 10 * time.Second}
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
@@ -84,7 +126,95 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
return strings.Join(lines, "\n"), nil
}
type DuckDuckGoSearchProvider struct{}
type TavilySearchProvider struct {
apiKey string
baseURL string
proxy string
}
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := p.baseURL
if searchURL == "" {
searchURL = "https://api.tavily.com/search"
}
payload := map[string]any{
"api_key": p.apiKey,
"query": query,
"search_depth": "advanced",
"include_answer": false,
"include_images": false,
"include_raw_content": false,
"max_results": count,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
} `json:"results"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Content != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Content))
}
}
return strings.Join(lines, "\n"), nil
}
type DuckDuckGoSearchProvider struct {
proxy string
}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query))
@@ -96,7 +226,10 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("User-Agent", userAgent)
client := &http.Client{Timeout: 10 * time.Second}
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
@@ -178,16 +311,23 @@ func stripTags(content string) string {
type PerplexitySearchProvider struct {
apiKey string
proxy string
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := "https://api.perplexity.ai/chat/completions"
payload := map[string]interface{}{
payload := map[string]any{
"model": "sonar",
"messages": []map[string]string{
{"role": "system", "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary."},
{"role": "user", "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count)},
{
"role": "system",
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
},
{
"role": "user",
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
},
},
"max_tokens": 1000,
}
@@ -206,7 +346,10 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
client := &http.Client{Timeout: 30 * time.Second}
client, err := createHTTPClient(p.proxy, 30*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
@@ -312,6 +455,10 @@ type WebSearchToolOptions struct {
BraveAPIKey string
BraveMaxResults int
BraveEnabled bool
TavilyAPIKey string
TavilyBaseURL string
TavilyMaxResults int
TavilyEnabled bool
DuckDuckGoMaxResults int
DuckDuckGoEnabled bool
PerplexityAPIKey string
@@ -320,20 +467,21 @@ type WebSearchToolOptions struct {
SearXNGBaseURL string
SearXNGMaxResults int
SearXNGEnabled bool
Proxy string
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > SearXNG > DuckDuckGo
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey}
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
@@ -342,8 +490,17 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
if opts.SearXNGMaxResults > 0 {
maxResults = opts.SearXNGMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
provider = &TavilySearchProvider{
apiKey: opts.TavilyAPIKey,
baseURL: opts.TavilyBaseURL,
proxy: opts.Proxy,
}
if opts.TavilyMaxResults > 0 {
maxResults = opts.TavilyMaxResults
}
} else if opts.DuckDuckGoEnabled {
provider = &DuckDuckGoSearchProvider{}
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
@@ -365,15 +522,15 @@ func (t *WebSearchTool) Description() string {
return "Search the web for current information. Returns titles, URLs, and snippets from search results."
}
func (t *WebSearchTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *WebSearchTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query",
},
"count": map[string]interface{}{
"count": map[string]any{
"type": "integer",
"description": "Number of results (1-10)",
"minimum": 1.0,
@@ -384,7 +541,7 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
}
}
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
query, ok := args["query"].(string)
if !ok {
return ErrorResult("query is required")
@@ -410,6 +567,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
type WebFetchTool struct {
maxChars int
proxy string
}
func NewWebFetchTool(maxChars int) *WebFetchTool {
@@ -421,6 +579,16 @@ func NewWebFetchTool(maxChars int) *WebFetchTool {
}
}
func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
if maxChars <= 0 {
maxChars = 50000
}
return &WebFetchTool{
maxChars: maxChars,
proxy: proxy,
}
}
func (t *WebFetchTool) Name() string {
return "web_fetch"
}
@@ -429,15 +597,15 @@ func (t *WebFetchTool) Description() string {
return "Fetch a URL and extract readable content (HTML to text). Use this to get weather info, news, articles, or any web content."
}
func (t *WebFetchTool) Parameters() map[string]interface{} {
return map[string]interface{}{
func (t *WebFetchTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]interface{}{
"url": map[string]interface{}{
"properties": map[string]any{
"url": map[string]any{
"type": "string",
"description": "URL to fetch",
},
"maxChars": map[string]interface{}{
"maxChars": map[string]any{
"type": "integer",
"description": "Maximum characters to extract",
"minimum": 100.0,
@@ -447,7 +615,7 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
}
}
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
urlStr, ok := args["url"].(string)
if !ok {
return ErrorResult("url is required")
@@ -480,20 +648,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
req.Header.Set("User-Agent", userAgent)
client := &http.Client{
Timeout: 60 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
DisableCompression: false,
TLSHandshakeTimeout: 15 * time.Second,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 5 {
return fmt.Errorf("stopped after 5 redirects")
}
return nil
},
client, err := createHTTPClient(t.proxy, 60*time.Second)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
}
// Configure redirect handling
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= 5 {
return fmt.Errorf("stopped after 5 redirects")
}
return nil
}
resp, err := client.Do(req)
@@ -512,7 +677,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
var text, extractor string
if strings.Contains(contentType, "application/json") {
var jsonData interface{}
var jsonData any
if err := json.Unmarshal(body, &jsonData); err == nil {
formatted, _ := json.MarshalIndent(jsonData, "", " ")
text = string(formatted)
@@ -535,7 +700,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
text = text[:maxChars]
}
result := map[string]interface{}{
result := map[string]any{
"url": urlStr,
"status": resp.StatusCode,
"extractor": extractor,
@@ -547,7 +712,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
resultJSON, _ := json.MarshalIndent(result, "", " ")
return &ToolResult{
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
ForLLM: fmt.Sprintf(
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
len(text),
urlStr,
extractor,
truncated,
),
ForUser: string(resultJSON),
}
}
+256 -12
View File
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
@@ -20,7 +21,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"url": server.URL,
}
@@ -56,7 +57,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"url": server.URL,
}
@@ -77,7 +78,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"url": "not-a-valid-url",
}
@@ -98,7 +99,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"url": "ftp://example.com/file.txt",
}
@@ -119,7 +120,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{}
args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -147,7 +148,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
tool := NewWebFetchTool(1000) // Limit to 1000 chars
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"url": server.URL,
}
@@ -159,7 +160,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
// ForUser should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]interface{})
resultMap := make(map[string]any)
json.Unmarshal([]byte(result.ForUser), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
@@ -191,7 +192,7 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
ctx := context.Background()
args := map[string]interface{}{}
args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -206,13 +207,17 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
w.Write(
[]byte(
`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`,
),
)
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"url": server.URL,
}
@@ -251,7 +256,8 @@ func TestWebFetchTool_extractText(t *testing.T) {
if len(lines) < 2 {
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
}
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") {
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") ||
!strings.Contains(got, "Paragraph 2") {
t.Errorf("Missing expected text: %q", got)
}
},
@@ -312,7 +318,7 @@ func TestWebFetchTool_extractText(t *testing.T) {
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
args := map[string]any{
"url": "https://",
}
@@ -328,3 +334,241 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
}
}
func TestCreateHTTPClient_ProxyConfigured(t *testing.T) {
client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
if client.Timeout != 12*time.Second {
t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
if tr.Proxy == nil {
t.Fatal("transport.Proxy is nil, want non-nil")
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
proxyURL, err := tr.Proxy(req)
if err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
}
}
func TestCreateHTTPClient_InvalidProxy(t *testing.T) {
_, err := createHTTPClient("://bad-proxy", 10*time.Second)
if err == nil {
t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil")
}
}
func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) {
client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
proxyURL, err := tr.Proxy(req)
if err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" {
t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080")
}
}
func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) {
_, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second)
if err == nil {
t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil")
}
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme")
}
}
func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
t.Setenv("http_proxy", "http://127.0.0.1:8888")
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
t.Setenv("https_proxy", "http://127.0.0.1:8888")
t.Setenv("ALL_PROXY", "")
t.Setenv("all_proxy", "")
t.Setenv("NO_PROXY", "")
t.Setenv("no_proxy", "")
client, err := createHTTPClient("", 10*time.Second)
if err != nil {
t.Fatalf("createHTTPClient() error: %v", err)
}
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
if tr.Proxy == nil {
t.Fatal("transport.Proxy is nil, want proxy function from environment")
}
req, err := http.NewRequest("GET", "https://example.com", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
if _, err := tr.Proxy(req); err != nil {
t.Fatalf("transport.Proxy(req) error: %v", err)
}
}
func TestNewWebFetchToolWithProxy(t *testing.T) {
tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
if tool.maxChars != 1024 {
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
}
if tool.proxy != "http://127.0.0.1:7890" {
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
}
tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
if tool.maxChars != 50000 {
t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
}
}
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
PerplexityEnabled: true,
PerplexityAPIKey: "k",
PerplexityMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
p, ok := tool.provider.(*PerplexitySearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
}
if p.proxy != "http://127.0.0.1:7890" {
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
}
})
t.Run("brave", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKey: "k",
BraveMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
p, ok := tool.provider.(*BraveSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
}
if p.proxy != "http://127.0.0.1:7890" {
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
}
})
t.Run("duckduckgo", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: true,
DuckDuckGoMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
}
if p.proxy != "http://127.0.0.1:7890" {
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
}
})
}
// TestWebTool_TavilySearch_Success verifies successful Tavily search
func TestWebTool_TavilySearch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
// Verify payload
var payload map[string]any
json.NewDecoder(r.Body).Decode(&payload)
if payload["api_key"] != "test-key" {
t.Errorf("Expected api_key test-key, got %v", payload["api_key"])
}
if payload["query"] != "test query" {
t.Errorf("Expected query 'test query', got %v", payload["query"])
}
// Return mock response
response := map[string]any{
"results": []map[string]any{
{
"title": "Test Result 1",
"url": "https://example.com/1",
"content": "Content for result 1",
},
{
"title": "Test Result 2",
"url": "https://example.com/2",
"content": "Content for result 2",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tool := NewWebSearchTool(WebSearchToolOptions{
TavilyEnabled: true,
TavilyAPIKey: "test-key",
TavilyBaseURL: server.URL,
TavilyMaxResults: 5,
})
ctx := context.Background()
args := map[string]any{
"query": "test query",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain result titles and URLs
if !strings.Contains(result.ForUser, "Test Result 1") ||
!strings.Contains(result.ForUser, "https://example.com/1") {
t.Errorf("Expected results in output, got: %s", result.ForUser)
}
// Should mention via Tavily
if !strings.Contains(result.ForUser, "via Tavily") {
t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser)
}
}