mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge remote-tracking branch 'origin/main' into feat/searxng
This commit is contained in:
+5
-5
@@ -6,8 +6,8 @@ import "context"
|
||||
type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Parameters() map[string]interface{}
|
||||
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
|
||||
Parameters() map[string]any
|
||||
Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
}
|
||||
|
||||
// ContextualTool is an optional interface that tools can implement
|
||||
@@ -69,10 +69,10 @@ type AsyncTool interface {
|
||||
SetCallback(cb AsyncCallback)
|
||||
}
|
||||
|
||||
func ToolToSchema(tool Tool) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func ToolToSchema(tool Tool) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"function": map[string]any{
|
||||
"name": tool.Name(),
|
||||
"description": tool.Description(),
|
||||
"parameters": tool.Parameters(),
|
||||
|
||||
+20
-18
@@ -30,7 +30,10 @@ type CronTool struct {
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
// execTimeout: 0 means no timeout, >0 sets the timeout duration
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *CronTool {
|
||||
func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) *CronTool {
|
||||
execTool := NewExecToolWithConfig(workspace, restrict, config)
|
||||
execTool.SetTimeout(execTimeout)
|
||||
return &CronTool{
|
||||
@@ -52,40 +55,40 @@ func (t *CronTool) Description() string {
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *CronTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *CronTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"add", "list", "remove", "enable", "disable"},
|
||||
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.",
|
||||
},
|
||||
"command": map[string]interface{}{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.",
|
||||
},
|
||||
"at_seconds": map[string]interface{}{
|
||||
"at_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
|
||||
},
|
||||
"every_seconds": map[string]interface{}{
|
||||
"every_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
|
||||
},
|
||||
"cron_expr": map[string]interface{}{
|
||||
"cron_expr": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
|
||||
},
|
||||
"job_id": map[string]interface{}{
|
||||
"job_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Job ID (for remove/enable/disable)",
|
||||
},
|
||||
"deliver": map[string]interface{}{
|
||||
"deliver": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
},
|
||||
@@ -103,7 +106,7 @@ func (t *CronTool) SetContext(channel, chatID string) {
|
||||
}
|
||||
|
||||
// Execute runs the tool with the given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("action is required")
|
||||
@@ -125,7 +128,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *To
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
|
||||
func (t *CronTool) addJob(args map[string]any) *ToolResult {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
@@ -233,7 +236,7 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
return SilentResult(result)
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
|
||||
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return ErrorResult("job_id is required for remove")
|
||||
@@ -245,7 +248,7 @@ func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
|
||||
func (t *CronTool) enableJob(args map[string]any, enable bool) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return ErrorResult("job_id is required for enable/disable")
|
||||
@@ -279,7 +282,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
|
||||
// Execute command if present
|
||||
if job.Payload.Command != "" {
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": job.Payload.Command,
|
||||
}
|
||||
|
||||
@@ -320,7 +323,6 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err)
|
||||
}
|
||||
|
||||
+77
-65
@@ -2,24 +2,27 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"io/fs"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EditFileTool edits a file by replacing old_text with new_text.
|
||||
// The old_text must exist exactly in the file.
|
||||
type EditFileTool struct {
|
||||
allowedDir string
|
||||
restrict bool
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
|
||||
func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool {
|
||||
return &EditFileTool{
|
||||
allowedDir: allowedDir,
|
||||
restrict: restrict,
|
||||
func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
}
|
||||
return &EditFileTool{fs: fs}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Name() string {
|
||||
@@ -30,19 +33,19 @@ func (t *EditFileTool) Description() string {
|
||||
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *EditFileTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The file path to edit",
|
||||
},
|
||||
"old_text": map[string]interface{}{
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The exact text to find and replace",
|
||||
},
|
||||
"new_text": map[string]interface{}{
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace with",
|
||||
},
|
||||
@@ -51,7 +54,7 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("path is required")
|
||||
@@ -67,47 +70,24 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
return ErrorResult("new_text is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
|
||||
if err != nil {
|
||||
if err := editFile(t.fs, path, oldText, newText); err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
|
||||
return ErrorResult(fmt.Sprintf("file not found: %s", path))
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
|
||||
if !strings.Contains(contentStr, oldText) {
|
||||
return ErrorResult("old_text not found in file. Make sure it matches exactly")
|
||||
}
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return SilentResult(fmt.Sprintf("File edited: %s", path))
|
||||
}
|
||||
|
||||
type AppendFileTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
|
||||
return &AppendFileTool{workspace: workspace, restrict: restrict}
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
}
|
||||
return &AppendFileTool{fs: fs}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Name() string {
|
||||
@@ -118,15 +98,15 @@ func (t *AppendFileTool) Description() string {
|
||||
return "Append content to the end of a file"
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *AppendFileTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The file path to append to",
|
||||
},
|
||||
"content": map[string]interface{}{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The content to append",
|
||||
},
|
||||
@@ -135,7 +115,7 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("path is required")
|
||||
@@ -146,20 +126,52 @@ func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
if err := appendFile(t.fs, path, content); err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
|
||||
}
|
||||
|
||||
return SilentResult(fmt.Sprintf("Appended to %s", path))
|
||||
}
|
||||
|
||||
// editFile reads the file via sysFs, performs the replacement, and writes back.
|
||||
// It uses a fileSystem interface, allowing the same logic for both restricted and unrestricted modes.
|
||||
func editFile(sysFs fileSystem, path, oldText, newText string) error {
|
||||
content, err := sysFs.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newContent, err := replaceEditContent(content, oldText, newText)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sysFs.WriteFile(path, newContent)
|
||||
}
|
||||
|
||||
// appendFile reads the existing content (if any) via sysFs, appends new content, and writes back.
|
||||
func appendFile(sysFs fileSystem, path, appendContent string) error {
|
||||
content, err := sysFs.ReadFile(path)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
newContent := append(content, []byte(appendContent)...)
|
||||
return sysFs.WriteFile(path, newContent)
|
||||
}
|
||||
|
||||
// replaceEditContent handles the core logic of finding and replacing a single occurrence of oldText.
|
||||
func replaceEditContent(content []byte, oldText, newText string) ([]byte, error) {
|
||||
contentStr := string(content)
|
||||
|
||||
if !strings.Contains(contentStr, oldText) {
|
||||
return nil, fmt.Errorf("old_text not found in file. Make sure it matches exactly")
|
||||
}
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
return []byte(newContent), nil
|
||||
}
|
||||
|
||||
+170
-22
@@ -6,17 +6,19 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestEditTool_EditFile_Success verifies successful file editing
|
||||
func TestEditTool_EditFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
|
||||
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0o644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"old_text": "World",
|
||||
"new_text": "Universe",
|
||||
@@ -60,7 +62,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
@@ -83,11 +85,11 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
|
||||
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World"), 0644)
|
||||
os.WriteFile(testFile, []byte("Hello World"), 0o644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"old_text": "Goodbye",
|
||||
"new_text": "Hello",
|
||||
@@ -110,11 +112,11 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
|
||||
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test test test"), 0644)
|
||||
os.WriteFile(testFile, []byte("test test test"), 0o644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"old_text": "test",
|
||||
"new_text": "done",
|
||||
@@ -138,11 +140,11 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
otherDir := t.TempDir()
|
||||
testFile := filepath.Join(otherDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("content"), 0644)
|
||||
os.WriteFile(testFile, []byte("content"), 0o644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"old_text": "content",
|
||||
"new_text": "new",
|
||||
@@ -151,21 +153,25 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is outside allowed directory")
|
||||
}
|
||||
assert.True(t, result.IsError, "Expected error when path is outside allowed directory")
|
||||
|
||||
// Should mention outside allowed directory
|
||||
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
|
||||
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
// Note: ErrorResult only sets ForLLM by default, so ForUser might be empty.
|
||||
// We check ForLLM as it's the primary error channel.
|
||||
assert.True(
|
||||
t,
|
||||
strings.Contains(result.ForLLM, "outside") || strings.Contains(result.ForLLM, "access denied") ||
|
||||
strings.Contains(result.ForLLM, "escapes"),
|
||||
"Expected 'outside allowed' or 'access denied' message, got ForLLM: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
|
||||
func TestEditTool_EditFile_MissingPath(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
@@ -182,7 +188,7 @@ func TestEditTool_EditFile_MissingPath(t *testing.T) {
|
||||
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": "/tmp/test.txt",
|
||||
"new_text": "new",
|
||||
}
|
||||
@@ -199,7 +205,7 @@ func TestEditTool_EditFile_MissingOldText(t *testing.T) {
|
||||
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": "/tmp/test.txt",
|
||||
"old_text": "old",
|
||||
}
|
||||
@@ -216,11 +222,11 @@ func TestEditTool_EditFile_MissingNewText(t *testing.T) {
|
||||
func TestEditTool_AppendFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Initial content"), 0644)
|
||||
os.WriteFile(testFile, []byte("Initial content"), 0o644)
|
||||
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"content": "\nAppended content",
|
||||
}
|
||||
@@ -260,7 +266,7 @@ func TestEditTool_AppendFile_Success(t *testing.T) {
|
||||
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
@@ -276,7 +282,7 @@ func TestEditTool_AppendFile_MissingPath(t *testing.T) {
|
||||
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
@@ -287,3 +293,145 @@ func TestEditTool_AppendFile_MissingContent(t *testing.T) {
|
||||
t.Errorf("Expected error when content is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReplaceEditContent verifies the helper function replaceEditContent
|
||||
func TestReplaceEditContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
oldText string
|
||||
newText string
|
||||
expected []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful replacement",
|
||||
content: []byte("hello world"),
|
||||
oldText: "world",
|
||||
newText: "universe",
|
||||
expected: []byte("hello universe"),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "old text not found",
|
||||
content: []byte("hello world"),
|
||||
oldText: "golang",
|
||||
newText: "rust",
|
||||
expected: nil,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "multiple matches found",
|
||||
content: []byte("test text test"),
|
||||
oldText: "test",
|
||||
newText: "done",
|
||||
expected: nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := replaceEditContent(tt.content, tt.oldText, tt.newText)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAppendFileTool_AppendToNonExistent_Restricted verifies that AppendFileTool in restricted mode
|
||||
// can append to a file that does not yet exist — it should silently create the file.
|
||||
// This exercises the errors.Is(err, fs.ErrNotExist) path in appendFileWithRW + rootRW.
|
||||
func TestAppendFileTool_AppendToNonExistent_Restricted(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
tool := NewAppendFileTool(workspace, true)
|
||||
ctx := context.Background()
|
||||
|
||||
args := map[string]any{
|
||||
"path": "brand_new_file.txt",
|
||||
"content": "first content",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
assert.False(
|
||||
t,
|
||||
result.IsError,
|
||||
"Expected success when appending to non-existent file in restricted mode, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
|
||||
// Verify the file was created with correct content
|
||||
data, err := os.ReadFile(filepath.Join(workspace, "brand_new_file.txt"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "first content", string(data))
|
||||
}
|
||||
|
||||
// TestAppendFileTool_Restricted_Success verifies that AppendFileTool in restricted mode
|
||||
// correctly appends to an existing file within the sandbox.
|
||||
func TestAppendFileTool_Restricted_Success(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
testFile := "existing.txt"
|
||||
err := os.WriteFile(filepath.Join(workspace, testFile), []byte("initial"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tool := NewAppendFileTool(workspace, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"content": " appended",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
|
||||
assert.True(t, result.Silent)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, testFile))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "initial appended", string(data))
|
||||
}
|
||||
|
||||
// TestEditFileTool_Restricted_InPlaceEdit verifies that EditFileTool in restricted mode
|
||||
// correctly edits a file using the single-open editFileInRoot path.
|
||||
func TestEditFileTool_Restricted_InPlaceEdit(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
testFile := "edit_target.txt"
|
||||
err := os.WriteFile(filepath.Join(workspace, testFile), []byte("Hello World"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tool := NewEditFileTool(workspace, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"old_text": "World",
|
||||
"new_text": "Go",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
|
||||
assert.True(t, result.Silent)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, testFile))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello Go", string(data))
|
||||
}
|
||||
|
||||
// TestEditFileTool_Restricted_FileNotFound verifies that editFileInRoot returns a proper
|
||||
// error message when the target file does not exist.
|
||||
func TestEditFileTool_Restricted_FileNotFound(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
tool := NewEditFileTool(workspace, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": "no_such_file.txt",
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "not found")
|
||||
}
|
||||
|
||||
+216
-58
@@ -3,15 +3,17 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// validatePath ensures the given path is within the workspace if restrict is true.
|
||||
func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, nil
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
|
||||
absWorkspace, err := filepath.Abs(workspace)
|
||||
@@ -34,17 +36,19 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
|
||||
var resolved string
|
||||
workspaceReal := absWorkspace
|
||||
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
|
||||
if resolved, err = filepath.EvalSymlinks(absWorkspace); err == nil {
|
||||
workspaceReal = resolved
|
||||
}
|
||||
|
||||
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
|
||||
if resolved, err = filepath.EvalSymlinks(absPath); err == nil {
|
||||
if !isWithinWorkspace(resolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
} else if os.IsNotExist(err) {
|
||||
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
|
||||
var parentResolved string
|
||||
if parentResolved, err = resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
|
||||
if !isWithinWorkspace(parentResolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
@@ -74,16 +78,21 @@ func resolveExistingAncestor(path string) (string, error) {
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
|
||||
return &ReadFileTool{workspace: workspace, restrict: restrict}
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
}
|
||||
return &ReadFileTool{fs: fs}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Name() string {
|
||||
@@ -94,11 +103,11 @@ func (t *ReadFileTool) Description() string {
|
||||
return "Read the contents of a file"
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *ReadFileTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to read",
|
||||
},
|
||||
@@ -107,32 +116,31 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
content, err := t.fs.ReadFile(path)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
return NewToolResult(string(content))
|
||||
}
|
||||
|
||||
type WriteFileTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
|
||||
return &WriteFileTool{workspace: workspace, restrict: restrict}
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
}
|
||||
return &WriteFileTool{fs: fs}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Name() string {
|
||||
@@ -143,15 +151,15 @@ func (t *WriteFileTool) Description() string {
|
||||
return "Write content to a file"
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *WriteFileTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to write",
|
||||
},
|
||||
"content": map[string]interface{}{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Content to write to the file",
|
||||
},
|
||||
@@ -160,7 +168,7 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("path is required")
|
||||
@@ -171,30 +179,25 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
if err := t.fs.WriteFile(path, []byte(content)); err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
dir := filepath.Dir(resolvedPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return SilentResult(fmt.Sprintf("File written: %s", path))
|
||||
}
|
||||
|
||||
type ListDirTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
|
||||
return &ListDirTool{workspace: workspace, restrict: restrict}
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
}
|
||||
return &ListDirTool{fs: fs}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Name() string {
|
||||
@@ -205,11 +208,11 @@ func (t *ListDirTool) Description() string {
|
||||
return "List files and directories in a path"
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *ListDirTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to list",
|
||||
},
|
||||
@@ -218,30 +221,185 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
path = "."
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(resolvedPath)
|
||||
entries, err := t.fs.ReadDir(path)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
|
||||
}
|
||||
return formatDirEntries(entries)
|
||||
}
|
||||
|
||||
result := ""
|
||||
func formatDirEntries(entries []os.DirEntry) *ToolResult {
|
||||
var result strings.Builder
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
result += "DIR: " + entry.Name() + "\n"
|
||||
result.WriteString("DIR: " + entry.Name() + "\n")
|
||||
} else {
|
||||
result += "FILE: " + entry.Name() + "\n"
|
||||
result.WriteString("FILE: " + entry.Name() + "\n")
|
||||
}
|
||||
}
|
||||
return NewToolResult(result.String())
|
||||
}
|
||||
|
||||
// fileSystem abstracts reading, writing, and listing files, allowing both
|
||||
// unrestricted (host filesystem) and sandbox (os.Root) implementations to share the same polymorphic interface.
|
||||
type fileSystem interface {
|
||||
ReadFile(path string) ([]byte, error)
|
||||
WriteFile(path string, data []byte) error
|
||||
ReadDir(path string) ([]os.DirEntry, error)
|
||||
}
|
||||
|
||||
// hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem.
|
||||
type hostFs struct{}
|
||||
|
||||
func (h *hostFs) ReadFile(path string) ([]byte, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("failed to read file: file not found: %w", err)
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return nil, fmt.Errorf("failed to read file: access denied: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (h *hostFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
return os.ReadDir(path)
|
||||
}
|
||||
|
||||
func (h *hostFs) WriteFile(path string, data []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create parent directories: %w", err)
|
||||
}
|
||||
|
||||
// We use a "write-then-rename" pattern here to ensure an atomic write.
|
||||
// This prevents the target file from being left in a truncated or partial state
|
||||
// if the operation is interrupted, as the rename operation is atomic on Linux.
|
||||
tmpPath := fmt.Sprintf("%s.%d.tmp", path, time.Now().UnixNano())
|
||||
if err := os.WriteFile(tmpPath, data, 0o644); err != nil {
|
||||
os.Remove(tmpPath) // Ensure cleanup of partial/empty temp file
|
||||
return fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
os.Remove(tmpPath)
|
||||
return fmt.Errorf("failed to replace original file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root.
|
||||
type sandboxFs struct {
|
||||
workspace string
|
||||
}
|
||||
|
||||
func (r *sandboxFs) execute(path string, fn func(root *os.Root, relPath string) error) error {
|
||||
if r.workspace == "" {
|
||||
return fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
|
||||
root, err := os.OpenRoot(r.workspace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open workspace: %w", err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
relPath, err := getSafeRelPath(r.workspace, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn(root, relPath)
|
||||
}
|
||||
|
||||
func (r *sandboxFs) ReadFile(path string) ([]byte, error) {
|
||||
var content []byte
|
||||
err := r.execute(path, func(root *os.Root, relPath string) error {
|
||||
fileContent, err := root.ReadFile(relPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to read file: file not found: %w", err)
|
||||
}
|
||||
// os.Root returns "escapes from parent" for paths outside the root
|
||||
if os.IsPermission(err) || strings.Contains(err.Error(), "escapes from parent") ||
|
||||
strings.Contains(err.Error(), "permission denied") {
|
||||
return fmt.Errorf("failed to read file: access denied: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
content = fileContent
|
||||
return nil
|
||||
})
|
||||
return content, err
|
||||
}
|
||||
|
||||
func (r *sandboxFs) WriteFile(path string, data []byte) error {
|
||||
return r.execute(path, func(root *os.Root, relPath string) error {
|
||||
dir := filepath.Dir(relPath)
|
||||
if dir != "." && dir != "/" {
|
||||
if err := root.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create parent directories: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// We use a "write-then-rename" pattern here to ensure an atomic write.
|
||||
// This prevents the target file from being left in a truncated or partial state
|
||||
// if the operation is interrupted, as the rename operation is atomic on Linux.
|
||||
tmpRelPath := fmt.Sprintf("%s.%d.tmp", relPath, time.Now().UnixNano())
|
||||
|
||||
if err := root.WriteFile(tmpRelPath, data, 0o644); err != nil {
|
||||
root.Remove(tmpRelPath) // Ensure cleanup of partial/empty temp file
|
||||
return fmt.Errorf("failed to write to temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := root.Rename(tmpRelPath, relPath); err != nil {
|
||||
root.Remove(tmpRelPath)
|
||||
return fmt.Errorf("failed to rename temp file over target: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
var entries []os.DirEntry
|
||||
err := r.execute(path, func(root *os.Root, relPath string) error {
|
||||
dirEntries, err := fs.ReadDir(root.FS(), relPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entries = dirEntries
|
||||
return nil
|
||||
})
|
||||
return entries, err
|
||||
}
|
||||
|
||||
// Helper to get a safe relative path for os.Root usage
|
||||
func getSafeRelPath(workspace, path string) (string, error) {
|
||||
if workspace == "" {
|
||||
return "", fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
|
||||
rel := filepath.Clean(path)
|
||||
if filepath.IsAbs(rel) {
|
||||
var err error
|
||||
rel, err = filepath.Rel(workspace, rel)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to calculate relative path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return NewToolResult(result)
|
||||
if !filepath.IsLocal(rel) {
|
||||
return "", fmt.Errorf("path escapes workspace: %s", path)
|
||||
}
|
||||
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
+236
-29
@@ -2,21 +2,24 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestFilesystemTool_ReadFile_Success verifies successful file reading
|
||||
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
os.WriteFile(testFile, []byte("test content"), 0o644)
|
||||
|
||||
tool := &ReadFileTool{}
|
||||
tool := NewReadFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
}
|
||||
|
||||
@@ -41,9 +44,9 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
|
||||
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
|
||||
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
tool := NewReadFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": "/nonexistent_file_12345.txt",
|
||||
}
|
||||
|
||||
@@ -64,7 +67,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
args := map[string]any{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -84,9 +87,9 @@ func TestFilesystemTool_WriteFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
tool := NewWriteFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"content": "hello world",
|
||||
}
|
||||
@@ -123,9 +126,9 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
tool := NewWriteFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"content": "test",
|
||||
}
|
||||
@@ -149,9 +152,9 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
|
||||
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
tool := NewWriteFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
@@ -165,9 +168,9 @@ func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
|
||||
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
tool := NewWriteFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
@@ -179,7 +182,8 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
|
||||
if !strings.Contains(result.ForLLM, "content is required") &&
|
||||
!strings.Contains(result.ForUser, "content is required") {
|
||||
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -187,13 +191,13 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
|
||||
// TestFilesystemTool_ListDir_Success verifies successful directory listing
|
||||
func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
|
||||
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
|
||||
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0o644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644)
|
||||
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755)
|
||||
|
||||
tool := &ListDirTool{}
|
||||
tool := NewListDirTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": tmpDir,
|
||||
}
|
||||
|
||||
@@ -215,9 +219,9 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
|
||||
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
|
||||
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
tool := NewListDirTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"path": "/nonexistent_directory_12345",
|
||||
}
|
||||
|
||||
@@ -236,9 +240,9 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
|
||||
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
|
||||
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
tool := NewListDirTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
args := map[string]any{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -250,15 +254,14 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
|
||||
|
||||
// Block paths that look inside workspace but point outside via symlink.
|
||||
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
|
||||
root := t.TempDir()
|
||||
workspace := filepath.Join(root, "workspace")
|
||||
if err := os.MkdirAll(workspace, 0755); err != nil {
|
||||
if err := os.MkdirAll(workspace, 0o755); err != nil {
|
||||
t.Fatalf("failed to create workspace: %v", err)
|
||||
}
|
||||
|
||||
secret := filepath.Join(root, "secret.txt")
|
||||
if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
|
||||
if err := os.WriteFile(secret, []byte("top secret"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write secret file: %v", err)
|
||||
}
|
||||
|
||||
@@ -268,14 +271,218 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
}
|
||||
|
||||
tool := NewReadFileTool(workspace, true)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": link,
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
|
||||
// os.Root might return different errors depending on platform/implementation
|
||||
// but it definitely should error.
|
||||
// Our wrapper returns "access denied or file not found"
|
||||
if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
|
||||
!strings.Contains(result.ForLLM, "no such file") {
|
||||
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
|
||||
tool := NewReadFileTool("", true) // restrict=true but workspace=""
|
||||
|
||||
// Try to read a sensitive file (simulated by a temp file outside workspace)
|
||||
tmpDir := t.TempDir()
|
||||
secretFile := filepath.Join(tmpDir, "shadow")
|
||||
os.WriteFile(secretFile, []byte("secret data"), 0o600)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": secretFile,
|
||||
})
|
||||
|
||||
// We EXPECT IsError=true (access blocked due to empty workspace)
|
||||
assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
|
||||
|
||||
// Verify it failed for the right reason
|
||||
assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
|
||||
}
|
||||
|
||||
// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
|
||||
// single dir, deeply nested dirs, already-existing dirs, and a file blocking a directory path.
|
||||
func TestRootMkdirAll(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
root, err := os.OpenRoot(workspace)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open root: %v", err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
// Case 1: Single directory
|
||||
err = root.MkdirAll("dir1", 0o755)
|
||||
assert.NoError(t, err)
|
||||
_, err = os.Stat(filepath.Join(workspace, "dir1"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Case 2: Deeply nested directory
|
||||
err = root.MkdirAll("a/b/c/d", 0o755)
|
||||
assert.NoError(t, err)
|
||||
_, err = os.Stat(filepath.Join(workspace, "a/b/c/d"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Case 3: Already exists — must be idempotent
|
||||
err = root.MkdirAll("a/b/c/d", 0o755)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Case 4: A regular file blocks directory creation — must error
|
||||
err = os.WriteFile(filepath.Join(workspace, "file_exists"), []byte("data"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
err = root.MkdirAll("file_exists", 0o755)
|
||||
assert.Error(t, err, "expected error when a file exists at the directory path")
|
||||
}
|
||||
|
||||
func TestFilesystemTool_WriteFile_Restricted_CreateDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
tool := NewWriteFileTool(workspace, true)
|
||||
ctx := context.Background()
|
||||
|
||||
testFile := "deep/nested/path/to/file.txt"
|
||||
content := "deep content"
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
|
||||
|
||||
// Verify file content
|
||||
actualPath := filepath.Join(workspace, testFile)
|
||||
data, err := os.ReadFile(actualPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, content, string(data))
|
||||
}
|
||||
|
||||
// TestHostRW_Read_PermissionDenied verifies that hostRW.Read surfaces access denied errors.
|
||||
func TestHostRW_Read_PermissionDenied(t *testing.T) {
|
||||
if os.Getuid() == 0 {
|
||||
t.Skip("skipping permission test: running as root")
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
protected := filepath.Join(tmpDir, "protected.txt")
|
||||
err := os.WriteFile(protected, []byte("secret"), 0o000)
|
||||
assert.NoError(t, err)
|
||||
defer os.Chmod(protected, 0o644) // ensure cleanup
|
||||
|
||||
_, err = (&hostFs{}).ReadFile(protected)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "access denied")
|
||||
}
|
||||
|
||||
// TestHostRW_Read_Directory verifies that hostRW.Read returns an error when given a directory path.
|
||||
func TestHostRW_Read_Directory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, err := (&hostFs{}).ReadFile(tmpDir)
|
||||
assert.Error(t, err, "expected error when reading a directory as a file")
|
||||
}
|
||||
|
||||
// TestRootRW_Read_Directory verifies that rootRW.Read returns an error when given a directory.
|
||||
func TestRootRW_Read_Directory(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
root, err := os.OpenRoot(workspace)
|
||||
assert.NoError(t, err)
|
||||
defer root.Close()
|
||||
|
||||
// Create a subdirectory
|
||||
err = root.Mkdir("subdir", 0o755)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = (&sandboxFs{workspace: workspace}).ReadFile("subdir")
|
||||
assert.Error(t, err, "expected error when reading a directory as a file")
|
||||
}
|
||||
|
||||
// TestHostRW_Write_ParentDirMissing verifies that hostRW.Write creates parent dirs automatically.
|
||||
func TestHostRW_Write_ParentDirMissing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
target := filepath.Join(tmpDir, "a", "b", "c", "file.txt")
|
||||
|
||||
err := (&hostFs{}).WriteFile(target, []byte("hello"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
data, err := os.ReadFile(target)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hello", string(data))
|
||||
}
|
||||
|
||||
// TestRootRW_Write_ParentDirMissing verifies that rootRW.Write creates
|
||||
// nested parent directories automatically within the sandbox.
|
||||
func TestRootRW_Write_ParentDirMissing(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
relPath := "x/y/z/file.txt"
|
||||
err := (&sandboxFs{workspace: workspace}).WriteFile(relPath, []byte("nested"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, relPath))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "nested", string(data))
|
||||
}
|
||||
|
||||
// TestHostRW_Write verifies the hostRW.Write helper function
|
||||
func TestHostRW_Write(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "atomic_test.txt")
|
||||
testData := []byte("atomic test content")
|
||||
|
||||
err := (&hostFs{}).WriteFile(testFile, testData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(testFile)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testData, content)
|
||||
|
||||
// Verify it overwrites correctly
|
||||
newData := []byte("new atomic content")
|
||||
err = (&hostFs{}).WriteFile(testFile, newData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err = os.ReadFile(testFile)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, newData, content)
|
||||
}
|
||||
|
||||
// TestRootRW_Write verifies the rootRW.Write helper function
|
||||
func TestRootRW_Write(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
relPath := "atomic_root_test.txt"
|
||||
testData := []byte("atomic root test content")
|
||||
|
||||
erw := &sandboxFs{workspace: tmpDir}
|
||||
err := erw.WriteFile(relPath, testData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
root, err := os.OpenRoot(tmpDir)
|
||||
assert.NoError(t, err)
|
||||
defer root.Close()
|
||||
|
||||
f, err := root.Open(relPath)
|
||||
assert.NoError(t, err)
|
||||
defer f.Close()
|
||||
|
||||
content, err := io.ReadAll(f)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testData, content)
|
||||
|
||||
// Verify it overwrites correctly
|
||||
newData := []byte("new root atomic content")
|
||||
err = erw.WriteFile(relPath, newData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
f2, err := root.Open(relPath)
|
||||
assert.NoError(t, err)
|
||||
defer f2.Close()
|
||||
|
||||
content, err = io.ReadAll(f2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, newData, content)
|
||||
}
|
||||
|
||||
+25
-15
@@ -24,37 +24,37 @@ func (t *I2CTool) Description() string {
|
||||
return "Interact with I2C bus devices for reading sensors and controlling peripherals. Actions: detect (list buses), scan (find devices on a bus), read (read bytes from device), write (send bytes to device). Linux only."
|
||||
}
|
||||
|
||||
func (t *I2CTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *I2CTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"detect", "scan", "read", "write"},
|
||||
"description": "Action to perform: detect (list available I2C buses), scan (find devices on a bus), read (read bytes from a device), write (send bytes to a device)",
|
||||
},
|
||||
"bus": map[string]interface{}{
|
||||
"bus": map[string]any{
|
||||
"type": "string",
|
||||
"description": "I2C bus number (e.g. \"1\" for /dev/i2c-1). Required for scan/read/write.",
|
||||
},
|
||||
"address": map[string]interface{}{
|
||||
"address": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "7-bit I2C device address (0x03-0x77). Required for read/write.",
|
||||
},
|
||||
"register": map[string]interface{}{
|
||||
"register": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Register address to read from or write to. If set, sends register byte before read/write.",
|
||||
},
|
||||
"data": map[string]interface{}{
|
||||
"data": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{"type": "integer"},
|
||||
"items": map[string]any{"type": "integer"},
|
||||
"description": "Bytes to write (0-255 each). Required for write action.",
|
||||
},
|
||||
"length": map[string]interface{}{
|
||||
"length": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Number of bytes to read (1-256). Default: 1. Used with read action.",
|
||||
},
|
||||
"confirm": map[string]interface{}{
|
||||
"confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Must be true for write operations. Safety guard to prevent accidental writes.",
|
||||
},
|
||||
@@ -63,7 +63,7 @@ func (t *I2CTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *I2CTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
|
||||
}
|
||||
@@ -95,7 +95,9 @@ func (t *I2CTool) detect() *ToolResult {
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return SilentResult("No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)")
|
||||
return SilentResult(
|
||||
"No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)",
|
||||
)
|
||||
}
|
||||
|
||||
type busInfo struct {
|
||||
@@ -115,14 +117,20 @@ func (t *I2CTool) detect() *ToolResult {
|
||||
return SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result)))
|
||||
}
|
||||
|
||||
// Helper functions for I2C operations (used by platform-specific implementations)
|
||||
|
||||
// isValidBusID checks that a bus identifier is a simple number (prevents path injection)
|
||||
//
|
||||
//nolint:unused // Used by i2c_linux.go
|
||||
func isValidBusID(id string) bool {
|
||||
matched, _ := regexp.MatchString(`^\d+$`, id)
|
||||
return matched
|
||||
}
|
||||
|
||||
// parseI2CAddress extracts and validates an I2C address from args
|
||||
func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
|
||||
//
|
||||
//nolint:unused // Used by i2c_linux.go
|
||||
func parseI2CAddress(args map[string]any) (int, *ToolResult) {
|
||||
addrFloat, ok := args["address"].(float64)
|
||||
if !ok {
|
||||
return 0, ErrorResult("address is required (e.g. 0x38 for AHT20)")
|
||||
@@ -135,7 +143,9 @@ func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
|
||||
}
|
||||
|
||||
// parseI2CBus extracts and validates an I2C bus from args
|
||||
func parseI2CBus(args map[string]interface{}) (string, *ToolResult) {
|
||||
//
|
||||
//nolint:unused // Used by i2c_linux.go
|
||||
func parseI2CBus(args map[string]any) (string, *ToolResult) {
|
||||
bus, ok := args["bus"].(string)
|
||||
if !ok || bus == "" {
|
||||
return "", ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)")
|
||||
|
||||
+13
-9
@@ -74,7 +74,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
|
||||
// scan probes valid 7-bit addresses on a bus for connected devices.
|
||||
// Uses the same hybrid probe strategy as i2cdetect's MODE_AUTO:
|
||||
// SMBus Quick Write for most addresses, SMBus Read Byte for EEPROM ranges.
|
||||
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
@@ -99,7 +99,9 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
hasReadByte := funcs&i2cFuncSmbusReadByte != 0
|
||||
|
||||
if !hasQuick && !hasReadByte {
|
||||
return ErrorResult(fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath))
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath),
|
||||
)
|
||||
}
|
||||
|
||||
type deviceEntry struct {
|
||||
@@ -133,7 +135,7 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"bus": devPath,
|
||||
"devices": found,
|
||||
"count": len(found),
|
||||
@@ -142,7 +144,7 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
}
|
||||
|
||||
// readDevice reads bytes from an I2C device, optionally at a specific register
|
||||
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
func (t *I2CTool) readDevice(args map[string]any) *ToolResult {
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
@@ -180,7 +182,7 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
if reg < 0 || reg > 255 {
|
||||
return ErrorResult("register must be between 0x00 and 0xFF")
|
||||
}
|
||||
_, err := syscall.Write(fd, []byte{byte(reg)})
|
||||
_, err = syscall.Write(fd, []byte{byte(reg)})
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err))
|
||||
}
|
||||
@@ -201,7 +203,7 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
intBytes[i] = int(buf[i])
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"bus": devPath,
|
||||
"address": fmt.Sprintf("0x%02x", addr),
|
||||
"bytes": intBytes,
|
||||
@@ -212,10 +214,12 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
}
|
||||
|
||||
// writeDevice writes bytes to an I2C device, optionally at a specific register
|
||||
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
|
||||
func (t *I2CTool) writeDevice(args map[string]any) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult("write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.")
|
||||
return ErrorResult(
|
||||
"write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.",
|
||||
)
|
||||
}
|
||||
|
||||
bus, errResult := parseI2CBus(args)
|
||||
@@ -228,7 +232,7 @@ func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
|
||||
return errResult
|
||||
}
|
||||
|
||||
dataRaw, ok := args["data"].([]interface{})
|
||||
dataRaw, ok := args["data"].([]any)
|
||||
if !ok || len(dataRaw) == 0 {
|
||||
return ErrorResult("data is required for write (array of byte values 0-255)")
|
||||
}
|
||||
|
||||
@@ -3,16 +3,16 @@
|
||||
package tools
|
||||
|
||||
// scan is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
|
||||
// readDevice is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
func (t *I2CTool) readDevice(args map[string]any) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
|
||||
// writeDevice is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
|
||||
func (t *I2CTool) writeDevice(args map[string]any) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
|
||||
@@ -26,19 +26,19 @@ func (t *MessageTool) Description() string {
|
||||
return "Send a message to user on a chat channel. Use this when you want to communicate something."
|
||||
}
|
||||
|
||||
func (t *MessageTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *MessageTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"content": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message content to send",
|
||||
},
|
||||
"channel": map[string]interface{}{
|
||||
"channel": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, whatsapp, etc.)",
|
||||
},
|
||||
"chat_id": map[string]interface{}{
|
||||
"chat_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID",
|
||||
},
|
||||
@@ -62,7 +62,7 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
|
||||
+10
-10
@@ -19,7 +19,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
"channel": "custom-channel",
|
||||
"chat_id": "custom-chat-id",
|
||||
@@ -104,7 +104,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{} // content missing
|
||||
args := map[string]any{} // content missing
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -158,7 +158,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
@@ -179,7 +179,7 @@ func TestMessageTool_Execute_NotConfigured(t *testing.T) {
|
||||
// No SetSendCallback called
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
@@ -219,7 +219,7 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
t.Error("Expected type 'object'")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected properties to be a map")
|
||||
}
|
||||
@@ -231,7 +231,7 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check content property
|
||||
contentProp, ok := props["content"].(map[string]interface{})
|
||||
contentProp, ok := props["content"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected 'content' property")
|
||||
}
|
||||
@@ -240,7 +240,7 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check channel property (optional)
|
||||
channelProp, ok := props["channel"].(map[string]interface{})
|
||||
channelProp, ok := props["channel"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected 'channel' property")
|
||||
}
|
||||
@@ -249,7 +249,7 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check chat_id property (optional)
|
||||
chatIDProp, ok := props["chat_id"].(map[string]interface{})
|
||||
chatIDProp, ok := props["chat_id"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected 'chat_id' property")
|
||||
}
|
||||
|
||||
+44
-23
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -34,16 +35,22 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
|
||||
// If the tool implements AsyncTool and a non-nil callback is provided,
|
||||
// the callback will be set on the tool before execution.
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
|
||||
func (r *ToolRegistry) ExecuteWithContext(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
args map[string]any,
|
||||
channel, chatID string,
|
||||
asyncCallback AsyncCallback,
|
||||
) *ToolResult {
|
||||
logger.InfoCF("tool", "Tool execution started",
|
||||
map[string]interface{}{
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
"args": args,
|
||||
})
|
||||
@@ -51,7 +58,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
tool, ok := r.Get(name)
|
||||
if !ok {
|
||||
logger.ErrorCF("tool", "Tool not found",
|
||||
map[string]interface{}{
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
})
|
||||
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
|
||||
@@ -66,7 +73,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
|
||||
asyncTool.SetCallback(asyncCallback)
|
||||
logger.DebugCF("tool", "Async callback injected",
|
||||
map[string]interface{}{
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
})
|
||||
}
|
||||
@@ -78,20 +85,20 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
// Log based on result type
|
||||
if result.IsError {
|
||||
logger.ErrorCF("tool", "Tool execution failed",
|
||||
map[string]interface{}{
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
"error": result.ForLLM,
|
||||
})
|
||||
} else if result.Async {
|
||||
logger.InfoCF("tool", "Tool started (async)",
|
||||
map[string]interface{}{
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("tool", "Tool execution completed",
|
||||
map[string]interface{}{
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"result_length": len(result.ForLLM),
|
||||
@@ -101,13 +108,27 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
// sortedToolNames returns tool names in sorted order for deterministic iteration.
|
||||
// This is critical for KV cache stability: non-deterministic map iteration would
|
||||
// produce different system prompts and tool definitions on each call, invalidating
|
||||
// the LLM's prefix cache even when no tools have changed.
|
||||
func (r *ToolRegistry) sortedToolNames() []string {
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) GetDefinitions() []map[string]any {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
definitions := make([]map[string]interface{}, 0, len(r.tools))
|
||||
for _, tool := range r.tools {
|
||||
definitions = append(definitions, ToolToSchema(tool))
|
||||
sorted := r.sortedToolNames()
|
||||
definitions := make([]map[string]any, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
definitions = append(definitions, ToolToSchema(r.tools[name]))
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
@@ -118,19 +139,21 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
|
||||
for _, tool := range r.tools {
|
||||
sorted := r.sortedToolNames()
|
||||
definitions := make([]providers.ToolDefinition, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
tool := r.tools[name]
|
||||
schema := ToolToSchema(tool)
|
||||
|
||||
// Safely extract nested values with type checks
|
||||
fn, ok := schema["function"].(map[string]interface{})
|
||||
fn, ok := schema["function"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params, _ := fn["parameters"].(map[string]interface{})
|
||||
params, _ := fn["parameters"].(map[string]any)
|
||||
|
||||
definitions = append(definitions, providers.ToolDefinition{
|
||||
Type: "function",
|
||||
@@ -149,11 +172,7 @@ func (r *ToolRegistry) List() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
return r.sortedToolNames()
|
||||
}
|
||||
|
||||
// Count returns the number of registered tools.
|
||||
@@ -169,8 +188,10 @@ func (r *ToolRegistry) GetSummaries() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
summaries := make([]string, 0, len(r.tools))
|
||||
for _, tool := range r.tools {
|
||||
sorted := r.sortedToolNames()
|
||||
summaries := make([]string, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
tool := r.tools[name]
|
||||
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
|
||||
}
|
||||
return summaries
|
||||
|
||||
@@ -0,0 +1,350 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// --- mock types ---
|
||||
|
||||
type mockRegistryTool struct {
|
||||
name string
|
||||
desc string
|
||||
params map[string]any
|
||||
result *ToolResult
|
||||
}
|
||||
|
||||
func (m *mockRegistryTool) Name() string { return m.name }
|
||||
func (m *mockRegistryTool) Description() string { return m.desc }
|
||||
func (m *mockRegistryTool) Parameters() map[string]any { return m.params }
|
||||
func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolResult {
|
||||
return m.result
|
||||
}
|
||||
|
||||
type mockCtxTool struct {
|
||||
mockRegistryTool
|
||||
channel string
|
||||
chatID string
|
||||
}
|
||||
|
||||
func (m *mockCtxTool) SetContext(channel, chatID string) {
|
||||
m.channel = channel
|
||||
m.chatID = chatID
|
||||
}
|
||||
|
||||
type mockAsyncRegistryTool struct {
|
||||
mockRegistryTool
|
||||
cb AsyncCallback
|
||||
}
|
||||
|
||||
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
|
||||
m.cb = cb
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func newMockTool(name, desc string) *mockRegistryTool {
|
||||
return &mockRegistryTool{
|
||||
name: name,
|
||||
desc: desc,
|
||||
params: map[string]any{"type": "object"},
|
||||
result: SilentResult("ok"),
|
||||
}
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
|
||||
func TestNewToolRegistry(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected empty registry, got count %d", r.Count())
|
||||
}
|
||||
if len(r.List()) != 0 {
|
||||
t.Errorf("expected empty list, got %v", r.List())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_RegisterAndGet(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
tool := newMockTool("echo", "echoes input")
|
||||
r.Register(tool)
|
||||
|
||||
got, ok := r.Get("echo")
|
||||
if !ok {
|
||||
t.Fatal("expected to find registered tool")
|
||||
}
|
||||
if got.Name() != "echo" {
|
||||
t.Errorf("expected name 'echo', got %q", got.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Get_NotFound(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
_, ok := r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected ok=false for unregistered tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_RegisterOverwrite(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("dup", "first"))
|
||||
r.Register(newMockTool("dup", "second"))
|
||||
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected count 1 after overwrite, got %d", r.Count())
|
||||
}
|
||||
tool, _ := r.Get("dup")
|
||||
if tool.Description() != "second" {
|
||||
t.Errorf("expected overwritten description 'second', got %q", tool.Description())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_Success(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "greet",
|
||||
desc: "says hello",
|
||||
params: map[string]any{},
|
||||
result: SilentResult("hello"),
|
||||
})
|
||||
|
||||
result := r.Execute(context.Background(), "greet", nil)
|
||||
if result.IsError {
|
||||
t.Errorf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "hello" {
|
||||
t.Errorf("expected ForLLM 'hello', got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_NotFound(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
result := r.Execute(context.Background(), "missing", nil)
|
||||
if !result.IsError {
|
||||
t.Error("expected error for missing tool")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "not found") {
|
||||
t.Errorf("expected 'not found' in error, got %q", result.ForLLM)
|
||||
}
|
||||
if result.Err == nil {
|
||||
t.Error("expected Err to be set via WithError")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
|
||||
|
||||
if ct.channel != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", ct.channel)
|
||||
}
|
||||
if ct.chatID != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
|
||||
|
||||
if ct.channel != "" || ct.chatID != "" {
|
||||
t.Error("SetContext should not be called with empty channel/chatID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
at := &mockAsyncRegistryTool{
|
||||
mockRegistryTool: *newMockTool("async_tool", "async work"),
|
||||
}
|
||||
at.result = AsyncResult("started")
|
||||
r.Register(at)
|
||||
|
||||
called := false
|
||||
cb := func(_ context.Context, _ *ToolResult) { called = true }
|
||||
|
||||
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
|
||||
if at.cb == nil {
|
||||
t.Error("expected SetCallback to have been called")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("expected async result")
|
||||
}
|
||||
|
||||
at.cb(context.Background(), SilentResult("done"))
|
||||
if !called {
|
||||
t.Error("expected callback to be invoked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("alpha", "tool A"))
|
||||
|
||||
defs := r.GetDefinitions()
|
||||
if len(defs) != 1 {
|
||||
t.Fatalf("expected 1 definition, got %d", len(defs))
|
||||
}
|
||||
if defs[0]["type"] != "function" {
|
||||
t.Errorf("expected type 'function', got %v", defs[0]["type"])
|
||||
}
|
||||
fn, ok := defs[0]["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'function' key to be a map")
|
||||
}
|
||||
if fn["name"] != "alpha" {
|
||||
t.Errorf("expected name 'alpha', got %v", fn["name"])
|
||||
}
|
||||
if fn["description"] != "tool A" {
|
||||
t.Errorf("expected description 'tool A', got %v", fn["description"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ToProviderDefs(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
params := map[string]any{"type": "object", "properties": map[string]any{}}
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "beta",
|
||||
desc: "tool B",
|
||||
params: params,
|
||||
result: SilentResult("ok"),
|
||||
})
|
||||
|
||||
defs := r.ToProviderDefs()
|
||||
if len(defs) != 1 {
|
||||
t.Fatalf("expected 1 provider def, got %d", len(defs))
|
||||
}
|
||||
|
||||
want := providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "beta",
|
||||
Description: "tool B",
|
||||
Parameters: params,
|
||||
},
|
||||
}
|
||||
got := defs[0]
|
||||
if got.Type != want.Type {
|
||||
t.Errorf("Type: want %q, got %q", want.Type, got.Type)
|
||||
}
|
||||
if got.Function.Name != want.Function.Name {
|
||||
t.Errorf("Name: want %q, got %q", want.Function.Name, got.Function.Name)
|
||||
}
|
||||
if got.Function.Description != want.Function.Description {
|
||||
t.Errorf("Description: want %q, got %q", want.Function.Description, got.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_List(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("x", ""))
|
||||
r.Register(newMockTool("y", ""))
|
||||
|
||||
names := r.List()
|
||||
if len(names) != 2 {
|
||||
t.Fatalf("expected 2 names, got %d", len(names))
|
||||
}
|
||||
|
||||
nameSet := map[string]bool{}
|
||||
for _, n := range names {
|
||||
nameSet[n] = true
|
||||
}
|
||||
if !nameSet["x"] || !nameSet["y"] {
|
||||
t.Errorf("expected names {x, y}, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Count(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected 0, got %d", r.Count())
|
||||
}
|
||||
|
||||
r.Register(newMockTool("a", ""))
|
||||
r.Register(newMockTool("b", ""))
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2, got %d", r.Count())
|
||||
}
|
||||
|
||||
r.Register(newMockTool("a", "replaced"))
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2 after overwrite, got %d", r.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_GetSummaries(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("read_file", "Reads a file"))
|
||||
|
||||
summaries := r.GetSummaries()
|
||||
if len(summaries) != 1 {
|
||||
t.Fatalf("expected 1 summary, got %d", len(summaries))
|
||||
}
|
||||
if !strings.Contains(summaries[0], "`read_file`") {
|
||||
t.Errorf("expected backtick-quoted name in summary, got %q", summaries[0])
|
||||
}
|
||||
if !strings.Contains(summaries[0], "Reads a file") {
|
||||
t.Errorf("expected description in summary, got %q", summaries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolToSchema(t *testing.T) {
|
||||
tool := newMockTool("demo", "demo tool")
|
||||
schema := ToolToSchema(tool)
|
||||
|
||||
if schema["type"] != "function" {
|
||||
t.Errorf("expected type 'function', got %v", schema["type"])
|
||||
}
|
||||
fn, ok := schema["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'function' to be a map")
|
||||
}
|
||||
if fn["name"] != "demo" {
|
||||
t.Errorf("expected name 'demo', got %v", fn["name"])
|
||||
}
|
||||
if fn["description"] != "demo tool" {
|
||||
t.Errorf("expected description 'demo tool', got %v", fn["description"])
|
||||
}
|
||||
if fn["parameters"] == nil {
|
||||
t.Error("expected parameters to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
name := string(rune('A' + n%26))
|
||||
r.Register(newMockTool(name, "concurrent"))
|
||||
r.Get(name)
|
||||
r.Count()
|
||||
r.List()
|
||||
r.GetDefinitions()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if r.Count() == 0 {
|
||||
t.Error("expected tools to be registered after concurrent access")
|
||||
}
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func TestToolResultJSONStructure(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var parsed map[string]interface{}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("Failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
+45
-13
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -75,11 +76,11 @@ func NewExecTool(workingDir string, restrict bool) *ExecTool {
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
|
||||
enableDenyPatterns := true
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
enableDenyPatterns = execConfig.EnableDenyPatterns
|
||||
enableDenyPatterns := execConfig.EnableDenyPatterns
|
||||
if enableDenyPatterns {
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
if len(execConfig.CustomDenyPatterns) > 0 {
|
||||
fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns)
|
||||
for _, pattern := range execConfig.CustomDenyPatterns {
|
||||
@@ -90,8 +91,6 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
}
|
||||
denyPatterns = append(denyPatterns, re)
|
||||
}
|
||||
} else {
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
} else {
|
||||
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
|
||||
@@ -118,15 +117,15 @@ func (t *ExecTool) Description() string {
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
}
|
||||
|
||||
func (t *ExecTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *ExecTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"command": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The shell command to execute",
|
||||
},
|
||||
"working_dir": map[string]interface{}{
|
||||
"working_dir": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional working directory for the command",
|
||||
},
|
||||
@@ -135,7 +134,7 @@ func (t *ExecTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
command, ok := args["command"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("command is required")
|
||||
@@ -143,7 +142,15 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
|
||||
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
cwd = wd
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePath(wd, t.workingDir, true)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
cwd = resolvedWD
|
||||
} else {
|
||||
cwd = wd
|
||||
}
|
||||
}
|
||||
|
||||
if cwd == "" {
|
||||
@@ -177,18 +184,43 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
|
||||
cmd.Dir = cwd
|
||||
}
|
||||
|
||||
prepareCommandForTermination(cmd)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err := cmd.Start(); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cmd.Wait()
|
||||
}()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-cmdCtx.Done():
|
||||
_ = terminateProcessTree(cmd)
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
err = <-done
|
||||
}
|
||||
}
|
||||
|
||||
output := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
output += "\nSTDERR:\n" + stderr.String()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
|
||||
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
//go:build !windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func prepareCommandForTermination(cmd *exec.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
if pid <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Kill the entire process group spawned by the shell command.
|
||||
_ = syscall.Kill(-pid, syscall.SIGKILL)
|
||||
// Fallback kill on the shell process itself.
|
||||
_ = cmd.Process.Kill()
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
//go:build windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func prepareCommandForTermination(cmd *exec.Cmd) {
|
||||
// no-op on Windows
|
||||
}
|
||||
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
if pid <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run()
|
||||
_ = cmd.Process.Kill()
|
||||
return nil
|
||||
}
|
||||
+75
-11
@@ -14,7 +14,7 @@ func TestShellTool_Success(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "echo 'hello world'",
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestShellTool_Failure(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "ls /nonexistent_directory_12345",
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func TestShellTool_Timeout(t *testing.T) {
|
||||
tool.SetTimeout(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "sleep 10",
|
||||
}
|
||||
|
||||
@@ -91,12 +91,12 @@ func TestShellTool_WorkingDir(t *testing.T) {
|
||||
// Create temp directory
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
os.WriteFile(testFile, []byte("test content"), 0o644)
|
||||
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "cat test.txt",
|
||||
"working_dir": tmpDir,
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "rm -rf /",
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func TestShellTool_MissingCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
args := map[string]any{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -153,7 +153,7 @@ func TestShellTool_StderrCapture(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "sh -c 'echo stdout; echo stderr >&2'",
|
||||
}
|
||||
|
||||
@@ -174,7 +174,7 @@ func TestShellTool_OutputTruncation(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
// Generate long output (>10000 chars)
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
|
||||
}
|
||||
|
||||
@@ -186,6 +186,66 @@ func TestShellTool_OutputTruncation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_WorkingDir_OutsideWorkspace verifies that working_dir cannot escape the workspace directly
|
||||
func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
workspace := filepath.Join(root, "workspace")
|
||||
outsideDir := filepath.Join(root, "outside")
|
||||
if err := os.MkdirAll(workspace, 0o755); err != nil {
|
||||
t.Fatalf("failed to create workspace: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(outsideDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create outside dir: %v", err)
|
||||
}
|
||||
|
||||
tool := NewExecTool(workspace, true)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "pwd",
|
||||
"working_dir": outsideDir,
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected working_dir outside workspace to be blocked, got output: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_WorkingDir_SymlinkEscape verifies that a symlink inside the workspace
|
||||
// pointing outside cannot be used as working_dir to escape the sandbox.
|
||||
func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
workspace := filepath.Join(root, "workspace")
|
||||
secretDir := filepath.Join(root, "secret")
|
||||
if err := os.MkdirAll(workspace, 0o755); err != nil {
|
||||
t.Fatalf("failed to create workspace: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(secretDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create secret dir: %v", err)
|
||||
}
|
||||
os.WriteFile(filepath.Join(secretDir, "secret.txt"), []byte("top secret"), 0o644)
|
||||
|
||||
// symlink lives inside the workspace but resolves to secretDir outside it
|
||||
link := filepath.Join(workspace, "escape")
|
||||
if err := os.Symlink(secretDir, link); err != nil {
|
||||
t.Skipf("symlinks not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
tool := NewExecTool(workspace, true)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "cat secret.txt",
|
||||
"working_dir": link,
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink working_dir escape to be blocked, got output: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_RestrictToWorkspace verifies workspace restriction
|
||||
func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
@@ -193,7 +253,7 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
tool.SetRestrictToWorkspace(true)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"command": "cat ../../etc/passwd",
|
||||
}
|
||||
|
||||
@@ -205,6 +265,10 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
t.Errorf(
|
||||
"Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
//go:build !windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func processExists(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
err := syscall.Kill(pid, 0)
|
||||
return err == nil || err == syscall.EPERM
|
||||
}
|
||||
|
||||
func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
|
||||
tool := NewExecTool(t.TempDir(), false)
|
||||
tool.SetTimeout(500 * time.Millisecond)
|
||||
|
||||
args := map[string]any{
|
||||
// Spawn a child process that would outlive the shell unless process-group kill is used.
|
||||
"command": "sleep 60 & echo $! > child.pid; wait",
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), args)
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected timeout error, got success: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "timed out") {
|
||||
t.Fatalf("expected timeout message, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
childPIDPath := filepath.Join(tool.workingDir, "child.pid")
|
||||
data, err := os.ReadFile(childPIDPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read child pid file: %v", err)
|
||||
}
|
||||
|
||||
childPID, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse child pid: %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if !processExists(childPID) {
|
||||
return
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("child process %d is still running after timeout", childPID)
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// InstallSkillTool allows the LLM agent to install skills from registries.
|
||||
// It shares the same RegistryManager that FindSkillsTool uses,
|
||||
// so all registries configured in config are available for installation.
|
||||
type InstallSkillTool struct {
|
||||
registryMgr *skills.RegistryManager
|
||||
workspace string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewInstallSkillTool creates a new InstallSkillTool.
|
||||
// registryMgr is the shared registry manager (same instance as FindSkillsTool).
|
||||
// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/.
|
||||
func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool {
|
||||
return &InstallSkillTool{
|
||||
registryMgr: registryMgr,
|
||||
workspace: workspace,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Name() string {
|
||||
return "install_skill"
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Description() string {
|
||||
return "Install a skill from a registry by slug. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills."
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"slug": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')",
|
||||
},
|
||||
"version": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Specific version to install (optional, defaults to latest)",
|
||||
},
|
||||
"registry": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Registry to install from (required, e.g., 'clawhub')",
|
||||
},
|
||||
"force": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Force reinstall if skill already exists (default false)",
|
||||
},
|
||||
},
|
||||
"required": []string{"slug", "registry"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
// Install lock to prevent concurrent directory operations.
|
||||
// Ideally this should be done at a `slug` level, currently, its at a `workspace` level.
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Validate slug
|
||||
slug, _ := args["slug"].(string)
|
||||
if err := utils.ValidateSkillIdentifier(slug); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error()))
|
||||
}
|
||||
|
||||
// Validate registry
|
||||
registryName, _ := args["registry"].(string)
|
||||
if err := utils.ValidateSkillIdentifier(registryName); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error()))
|
||||
}
|
||||
|
||||
version, _ := args["version"].(string)
|
||||
force, _ := args["force"].(bool)
|
||||
|
||||
// Check if already installed.
|
||||
skillsDir := filepath.Join(t.workspace, "skills")
|
||||
targetDir := filepath.Join(skillsDir, slug)
|
||||
|
||||
if !force {
|
||||
if _, err := os.Stat(targetDir); err == nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Force: remove existing if present.
|
||||
os.RemoveAll(targetDir)
|
||||
}
|
||||
|
||||
// Resolve which registry to use.
|
||||
registry := t.registryMgr.GetRegistry(registryName)
|
||||
if registry == nil {
|
||||
return ErrorResult(fmt.Sprintf("registry %q not found", registryName))
|
||||
}
|
||||
|
||||
// Ensure skills directory exists.
|
||||
if err := os.MkdirAll(skillsDir, 0o755); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err))
|
||||
}
|
||||
|
||||
// Download and install (handles metadata, version resolution, extraction).
|
||||
result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir)
|
||||
if err != nil {
|
||||
// Clean up partial install.
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove partial install",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err))
|
||||
}
|
||||
|
||||
// Moderation: block malware.
|
||||
if result.IsMalwareBlocked {
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove partial install",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
|
||||
}
|
||||
|
||||
// Write origin metadata.
|
||||
if err := writeOriginMeta(targetDir, registry.Name(), slug, result.Version); err != nil {
|
||||
logger.ErrorCF("tool", "Failed to write origin metadata",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"error": err.Error(),
|
||||
"target": targetDir,
|
||||
"registry": registry.Name(),
|
||||
"slug": slug,
|
||||
"version": result.Version,
|
||||
})
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Build result with moderation warning if suspicious.
|
||||
var output string
|
||||
if result.IsSuspicious {
|
||||
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
|
||||
}
|
||||
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
|
||||
slug, result.Version, registry.Name(), targetDir)
|
||||
|
||||
if result.Summary != "" {
|
||||
output += fmt.Sprintf("Description: %s\n", result.Summary)
|
||||
}
|
||||
output += "\nThe skill is now available and can be loaded in the current session."
|
||||
|
||||
return SilentResult(output)
|
||||
}
|
||||
|
||||
// originMeta tracks which registry a skill was installed from.
|
||||
type originMeta struct {
|
||||
Version int `json:"version"`
|
||||
Registry string `json:"registry"`
|
||||
Slug string `json:"slug"`
|
||||
InstalledVersion string `json:"installed_version"`
|
||||
InstalledAt int64 `json:"installed_at"`
|
||||
}
|
||||
|
||||
func writeOriginMeta(targetDir, registryName, slug, version string) error {
|
||||
meta := originMeta{
|
||||
Version: 1,
|
||||
Registry: registryName,
|
||||
Slug: slug,
|
||||
InstalledVersion: version,
|
||||
InstalledAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(meta, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filepath.Join(targetDir, ".skill-origin.json"), data, 0o644)
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
func TestInstallSkillToolName(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
assert.Equal(t, "install_skill", tool.Name())
|
||||
}
|
||||
|
||||
func TestInstallSkillToolMissingSlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolEmptySlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": " ",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolUnsafeSlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
|
||||
cases := []string{
|
||||
"../etc/passwd",
|
||||
"path/traversal",
|
||||
"path\\traversal",
|
||||
}
|
||||
|
||||
for _, slug := range cases {
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": slug,
|
||||
})
|
||||
assert.True(t, result.IsError, "slug %q should be rejected", slug)
|
||||
assert.Contains(t, result.ForLLM, "invalid slug")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallSkillToolAlreadyExists(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
skillDir := filepath.Join(workspace, "skills", "existing-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "existing-skill",
|
||||
"registry": "clawhub",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "already installed")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolRegistryNotFound(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "some-skill",
|
||||
"registry": "nonexistent",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "registry")
|
||||
assert.Contains(t, result.ForLLM, "not found")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolParameters(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, props, "slug")
|
||||
assert.Contains(t, props, "version")
|
||||
assert.Contains(t, props, "registry")
|
||||
assert.Contains(t, props, "force")
|
||||
|
||||
required, ok := params["required"].([]string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, required, "slug")
|
||||
assert.Contains(t, required, "registry")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolMissingRegistry(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "some-skill",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "invalid registry")
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
// FindSkillsTool allows the LLM agent to search for installable skills from registries.
|
||||
type FindSkillsTool struct {
|
||||
registryMgr *skills.RegistryManager
|
||||
cache *skills.SearchCache
|
||||
}
|
||||
|
||||
// NewFindSkillsTool creates a new FindSkillsTool.
|
||||
// registryMgr is the shared registry manager (built from config in createToolRegistry).
|
||||
// cache is the search cache for deduplicating similar queries.
|
||||
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
|
||||
return &FindSkillsTool{
|
||||
registryMgr: registryMgr,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Name() string {
|
||||
return "find_skills"
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Description() string {
|
||||
return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill."
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')",
|
||||
},
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (1-20, default 5)",
|
||||
"minimum": 1.0,
|
||||
"maximum": 20.0,
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
query, ok := args["query"].(string)
|
||||
query = strings.ToLower(strings.TrimSpace(query))
|
||||
if !ok || query == "" {
|
||||
return ErrorResult("query is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
limit := 5
|
||||
if l, ok := args["limit"].(float64); ok {
|
||||
li := int(l)
|
||||
if li >= 1 && li <= 20 {
|
||||
limit = li
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache first.
|
||||
if t.cache != nil {
|
||||
if cached, hit := t.cache.Get(query); hit {
|
||||
return SilentResult(formatSearchResults(query, cached, true))
|
||||
}
|
||||
}
|
||||
|
||||
// Search all registries.
|
||||
results, err := t.registryMgr.SearchAll(ctx, query, limit)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("skill search failed: %v", err))
|
||||
}
|
||||
|
||||
// Cache the results.
|
||||
if t.cache != nil && len(results) > 0 {
|
||||
t.cache.Put(query, results)
|
||||
}
|
||||
|
||||
return SilentResult(formatSearchResults(query, results, false))
|
||||
}
|
||||
|
||||
func formatSearchResults(query string, results []skills.SearchResult, cached bool) string {
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No skills found for query: %q", query)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
source := ""
|
||||
if cached {
|
||||
source = " (cached)"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source))
|
||||
|
||||
for i, r := range results {
|
||||
sb.WriteString(fmt.Sprintf("%d. **%s**", i+1, r.Slug))
|
||||
if r.Version != "" {
|
||||
sb.WriteString(fmt.Sprintf(" v%s", r.Version))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName))
|
||||
if r.DisplayName != "" && r.DisplayName != r.Slug {
|
||||
sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName))
|
||||
}
|
||||
if r.Summary != "" {
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", r.Summary))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("Use install_skill with the slug to install a skill.")
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
func TestFindSkillsToolName(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
assert.Equal(t, "find_skills", tool.Name())
|
||||
}
|
||||
|
||||
func TestFindSkillsToolMissingQuery(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "query is required")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolEmptyQuery(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": " ",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
}
|
||||
|
||||
func TestFindSkillsToolCacheHit(t *testing.T) {
|
||||
cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min
|
||||
cache.Put("github", []skills.SearchResult{
|
||||
{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
|
||||
})
|
||||
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), cache)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": "github",
|
||||
})
|
||||
|
||||
assert.False(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "github")
|
||||
assert.Contains(t, result.ForLLM, "cached")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolParameters(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, props, "query")
|
||||
assert.Contains(t, props, "limit")
|
||||
|
||||
required, ok := params["required"].([]string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, required, "query")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolDescription(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
assert.NotEmpty(t, tool.Description())
|
||||
assert.Contains(t, tool.Description(), "skill")
|
||||
}
|
||||
|
||||
func TestFormatSearchResultsEmpty(t *testing.T) {
|
||||
result := formatSearchResults("test query", nil, false)
|
||||
assert.Contains(t, result, "No skills found")
|
||||
}
|
||||
|
||||
func TestFormatSearchResultsWithData(t *testing.T) {
|
||||
results := []skills.SearchResult{
|
||||
{
|
||||
Slug: "github",
|
||||
Score: 0.95,
|
||||
DisplayName: "GitHub",
|
||||
Summary: "GitHub API integration",
|
||||
Version: "1.0.0",
|
||||
RegistryName: "clawhub",
|
||||
},
|
||||
}
|
||||
output := formatSearchResults("github", results, false)
|
||||
assert.Contains(t, output, "github")
|
||||
assert.Contains(t, output, "v1.0.0")
|
||||
assert.Contains(t, output, "0.950")
|
||||
assert.Contains(t, output, "clawhub")
|
||||
assert.Contains(t, output, "install_skill")
|
||||
}
|
||||
+10
-9
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SpawnTool struct {
|
||||
@@ -34,19 +35,19 @@ func (t *SpawnTool) Description() string {
|
||||
return "Spawn a subagent to handle a task in the background. Use this for complex or time-consuming tasks that can run independently. The subagent will complete the task and report back when done."
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *SpawnTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"task": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"task": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The task for subagent to complete",
|
||||
},
|
||||
"label": map[string]interface{}{
|
||||
"label": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
"agent_id": map[string]interface{}{
|
||||
"agent_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional target agent ID to delegate the task to",
|
||||
},
|
||||
@@ -64,10 +65,10 @@ func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
|
||||
t.allowlistCheck = check
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("task is required")
|
||||
if !ok || strings.TrimSpace(task) == "" {
|
||||
return ErrorResult("task is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSpawnTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args map[string]any
|
||||
}{
|
||||
{"empty string", map[string]any{"task": ""}},
|
||||
{"whitespace only", map[string]any{"task": " "}},
|
||||
{"tabs and newlines", map[string]any{"task": "\t\n "}},
|
||||
{"missing task key", map[string]any{"label": "test"}},
|
||||
{"wrong type", map[string]any{"task": 123}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tool.Execute(ctx, tt.args)
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected error for invalid task parameter")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task is required") {
|
||||
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnTool_Execute_ValidTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSpawnTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"task": "Write a haiku about coding",
|
||||
"label": "haiku-task",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success for valid task, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("SpawnTool should return async result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnTool_Execute_NilManager(t *testing.T) {
|
||||
tool := NewSpawnTool(nil)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{"task": "test task"}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
if !result.IsError {
|
||||
t.Error("Expected error for nil manager")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
|
||||
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
+21
-15
@@ -24,41 +24,41 @@ func (t *SPITool) Description() string {
|
||||
return "Interact with SPI bus devices for high-speed peripheral communication. Actions: list (find SPI devices), transfer (full-duplex send/receive), read (receive bytes). Linux only."
|
||||
}
|
||||
|
||||
func (t *SPITool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *SPITool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"list", "transfer", "read"},
|
||||
"description": "Action to perform: list (find available SPI devices), transfer (full-duplex send/receive), read (receive bytes by sending zeros)",
|
||||
},
|
||||
"device": map[string]interface{}{
|
||||
"device": map[string]any{
|
||||
"type": "string",
|
||||
"description": "SPI device identifier (e.g. \"2.0\" for /dev/spidev2.0). Required for transfer/read.",
|
||||
},
|
||||
"speed": map[string]interface{}{
|
||||
"speed": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "SPI clock speed in Hz. Default: 1000000 (1 MHz).",
|
||||
},
|
||||
"mode": map[string]interface{}{
|
||||
"mode": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "SPI mode (0-3). Default: 0. Mode sets CPOL and CPHA: 0=0,0 1=0,1 2=1,0 3=1,1.",
|
||||
},
|
||||
"bits": map[string]interface{}{
|
||||
"bits": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Bits per word. Default: 8.",
|
||||
},
|
||||
"data": map[string]interface{}{
|
||||
"data": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{"type": "integer"},
|
||||
"items": map[string]any{"type": "integer"},
|
||||
"description": "Bytes to send (0-255 each). Required for transfer action.",
|
||||
},
|
||||
"length": map[string]interface{}{
|
||||
"length": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Number of bytes to read (1-4096). Required for read action.",
|
||||
},
|
||||
"confirm": map[string]interface{}{
|
||||
"confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Must be true for transfer operations. Safety guard to prevent accidental writes.",
|
||||
},
|
||||
@@ -67,7 +67,7 @@ func (t *SPITool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SPITool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.")
|
||||
}
|
||||
@@ -97,7 +97,9 @@ func (t *SPITool) list() *ToolResult {
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return SilentResult("No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded")
|
||||
return SilentResult(
|
||||
"No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded",
|
||||
)
|
||||
}
|
||||
|
||||
type devInfo struct {
|
||||
@@ -117,8 +119,12 @@ func (t *SPITool) list() *ToolResult {
|
||||
return SilentResult(fmt.Sprintf("Found %d SPI device(s):\n%s", len(devices), string(result)))
|
||||
}
|
||||
|
||||
// Helper function for SPI operations (used by platform-specific implementations)
|
||||
|
||||
// parseSPIArgs extracts and validates common SPI parameters
|
||||
func parseSPIArgs(args map[string]interface{}) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
|
||||
//
|
||||
//nolint:unused // Used by spi_linux.go
|
||||
func parseSPIArgs(args map[string]any) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
|
||||
dev, ok := args["device"].(string)
|
||||
if !ok || dev == "" {
|
||||
return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)"
|
||||
|
||||
@@ -66,10 +66,12 @@ func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *T
|
||||
}
|
||||
|
||||
// transfer performs a full-duplex SPI transfer
|
||||
func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
|
||||
func (t *SPITool) transfer(args map[string]any) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult("transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.")
|
||||
return ErrorResult(
|
||||
"transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.",
|
||||
)
|
||||
}
|
||||
|
||||
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
|
||||
@@ -77,7 +79,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult(errMsg)
|
||||
}
|
||||
|
||||
dataRaw, ok := args["data"].([]interface{})
|
||||
dataRaw, ok := args["data"].([]any)
|
||||
if !ok || len(dataRaw) == 0 {
|
||||
return ErrorResult("data is required for transfer (array of byte values 0-255)")
|
||||
}
|
||||
@@ -130,7 +132,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
|
||||
intBytes[i] = int(b)
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"device": devPath,
|
||||
"sent": len(txBuf),
|
||||
"received": intBytes,
|
||||
@@ -140,7 +142,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
|
||||
}
|
||||
|
||||
// readDevice reads bytes from SPI by sending zeros (read-only, no confirm needed)
|
||||
func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
func (t *SPITool) readDevice(args map[string]any) *ToolResult {
|
||||
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
|
||||
if errMsg != "" {
|
||||
return ErrorResult(errMsg)
|
||||
@@ -186,7 +188,7 @@ func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
intBytes[i] = int(b)
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"device": devPath,
|
||||
"bytes": intBytes,
|
||||
"hex": hexBytes,
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
package tools
|
||||
|
||||
// transfer is a stub for non-Linux platforms.
|
||||
func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
|
||||
func (t *SPITool) transfer(args map[string]any) *ToolResult {
|
||||
return ErrorResult("SPI is only supported on Linux")
|
||||
}
|
||||
|
||||
// readDevice is a stub for non-Linux platforms.
|
||||
func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
func (t *SPITool) readDevice(args map[string]any) *ToolResult {
|
||||
return ErrorResult("SPI is only supported on Linux")
|
||||
}
|
||||
|
||||
+28
-15
@@ -38,7 +38,11 @@ type SubagentManager struct {
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
|
||||
func NewSubagentManager(
|
||||
provider providers.LLMProvider,
|
||||
defaultModel, workspace string,
|
||||
bus *bus.MessageBus,
|
||||
) *SubagentManager {
|
||||
return &SubagentManager{
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
provider: provider,
|
||||
@@ -76,7 +80,11 @@ func (sm *SubagentManager) RegisterTool(tool Tool) {
|
||||
sm.tools.Register(tool)
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||
func (sm *SubagentManager) Spawn(
|
||||
ctx context.Context,
|
||||
task, label, agentID, originChannel, originChatID string,
|
||||
callback AsyncCallback,
|
||||
) (string, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
@@ -124,12 +132,12 @@ After completing the task, provide a clear summary of what was done.`
|
||||
},
|
||||
}
|
||||
|
||||
// Check if context is already cancelled before starting
|
||||
// Check if context is already canceled before starting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
sm.mu.Lock()
|
||||
task.Status = "cancelled"
|
||||
task.Result = "Task cancelled before execution"
|
||||
task.Status = "canceled"
|
||||
task.Result = "Task canceled before execution"
|
||||
sm.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
@@ -177,10 +185,10 @@ After completing the task, provide a clear summary of what was done.`
|
||||
if err != nil {
|
||||
task.Status = "failed"
|
||||
task.Result = fmt.Sprintf("Error: %v", err)
|
||||
// Check if it was cancelled
|
||||
// Check if it was canceled
|
||||
if ctx.Err() != nil {
|
||||
task.Status = "cancelled"
|
||||
task.Result = "Task cancelled during execution"
|
||||
task.Status = "canceled"
|
||||
task.Result = "Task canceled during execution"
|
||||
}
|
||||
result = &ToolResult{
|
||||
ForLLM: task.Result,
|
||||
@@ -194,7 +202,12 @@ After completing the task, provide a clear summary of what was done.`
|
||||
task.Status = "completed"
|
||||
task.Result = loopResult.Content
|
||||
result = &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
|
||||
ForLLM: fmt.Sprintf(
|
||||
"Subagent '%s' completed (iterations: %d): %s",
|
||||
task.Label,
|
||||
loopResult.Iterations,
|
||||
loopResult.Content,
|
||||
),
|
||||
ForUser: loopResult.Content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
@@ -258,15 +271,15 @@ func (t *SubagentTool) Description() string {
|
||||
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *SubagentTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"task": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"task": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The task for subagent to complete",
|
||||
},
|
||||
"label": map[string]interface{}{
|
||||
"label": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
@@ -280,7 +293,7 @@ func (t *SubagentTool) SetContext(channel, chatID string) {
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
|
||||
|
||||
@@ -11,10 +11,16 @@ import (
|
||||
|
||||
// MockLLMProvider is a test implementation of LLMProvider
|
||||
type MockLLMProvider struct {
|
||||
lastOptions map[string]interface{}
|
||||
lastOptions map[string]any
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
func (m *MockLLMProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.lastOptions = options
|
||||
// Find the last user message to generate a response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
@@ -47,7 +53,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
tool.SetContext("cli", "direct")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{"task": "Do something"}
|
||||
args := map[string]any{"task": "Do something"}
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result == nil || result.IsError {
|
||||
@@ -108,13 +114,13 @@ func TestSubagentTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check properties
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Properties should be a map")
|
||||
}
|
||||
|
||||
// Verify task parameter
|
||||
task, ok := props["task"].(map[string]interface{})
|
||||
task, ok := props["task"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Task parameter should exist")
|
||||
}
|
||||
@@ -123,7 +129,7 @@ func TestSubagentTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify label parameter
|
||||
label, ok := props["label"].(map[string]interface{})
|
||||
label, ok := props["label"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Label parameter should exist")
|
||||
}
|
||||
@@ -163,7 +169,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
tool.SetContext("telegram", "chat-123")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"task": "Write a haiku about coding",
|
||||
"label": "haiku-task",
|
||||
}
|
||||
@@ -218,7 +224,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"task": "Test task without label",
|
||||
}
|
||||
|
||||
@@ -241,7 +247,7 @@ func TestSubagentTool_Execute_MissingTask(t *testing.T) {
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"label": "test",
|
||||
}
|
||||
|
||||
@@ -268,7 +274,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) {
|
||||
tool := NewSubagentTool(nil)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"task": "test task",
|
||||
}
|
||||
|
||||
@@ -297,7 +303,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
tool.SetContext(channel, chatID)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"task": "Test context passing",
|
||||
}
|
||||
|
||||
@@ -324,7 +330,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
|
||||
// Create a task that will generate long response
|
||||
longTask := strings.Repeat("This is a very long task description. ", 100)
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"task": longTask,
|
||||
"label": "long-test",
|
||||
}
|
||||
|
||||
@@ -33,7 +33,12 @@ type ToolLoopResult struct {
|
||||
|
||||
// RunToolLoop executes the LLM + tool call iteration loop.
|
||||
// This is the core agent logic that can be reused by both main agent and subagents.
|
||||
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
|
||||
func RunToolLoop(
|
||||
ctx context.Context,
|
||||
config ToolLoopConfig,
|
||||
messages []providers.Message,
|
||||
channel, chatID string,
|
||||
) (*ToolLoopResult, error) {
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
|
||||
+15
-9
@@ -10,11 +10,11 @@ type Message struct {
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
@@ -36,7 +36,13 @@ type UsageInfo struct {
|
||||
}
|
||||
|
||||
type LLMProvider interface {
|
||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
|
||||
Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error)
|
||||
GetDefaultModel() string
|
||||
}
|
||||
|
||||
@@ -46,7 +52,7 @@ type ToolDefinition struct {
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]any `json:"parameters"`
|
||||
}
|
||||
|
||||
+211
-40
@@ -1,6 +1,7 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -16,12 +17,50 @@ const (
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
// createHTTPClient creates an HTTP client with optional proxy support
|
||||
func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DisableCompression: false,
|
||||
TLSHandshakeTimeout: 15 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
if proxyURL != "" {
|
||||
proxy, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
scheme := strings.ToLower(proxy.Scheme)
|
||||
switch scheme {
|
||||
case "http", "https", "socks5", "socks5h":
|
||||
default:
|
||||
return nil, fmt.Errorf(
|
||||
"unsupported proxy scheme %q (supported: http, https, socks5, socks5h)",
|
||||
proxy.Scheme,
|
||||
)
|
||||
}
|
||||
if proxy.Host == "" {
|
||||
return nil, fmt.Errorf("invalid proxy URL: missing host")
|
||||
}
|
||||
client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy)
|
||||
} else {
|
||||
client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type SearchProvider interface {
|
||||
Search(ctx context.Context, query string, count int) (string, error)
|
||||
}
|
||||
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -36,7 +75,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
@@ -84,7 +126,95 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct{}
|
||||
type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
proxy string
|
||||
}
|
||||
|
||||
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := p.baseURL
|
||||
if searchURL == "" {
|
||||
searchURL = "https://api.tavily.com/search"
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"api_key": p.apiKey,
|
||||
"query": query,
|
||||
"search_depth": "advanced",
|
||||
"include_answer": false,
|
||||
"include_images": false,
|
||||
"include_raw_content": false,
|
||||
"max_results": count,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct {
|
||||
proxy string
|
||||
}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query))
|
||||
@@ -96,7 +226,10 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
@@ -178,16 +311,23 @@ func stripTags(content string) string {
|
||||
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
payload := map[string]any{
|
||||
"model": "sonar",
|
||||
"messages": []map[string]string{
|
||||
{"role": "system", "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary."},
|
||||
{"role": "user", "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count)},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
|
||||
},
|
||||
},
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
@@ -206,7 +346,10 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client, err := createHTTPClient(p.proxy, 30*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
@@ -312,6 +455,10 @@ type WebSearchToolOptions struct {
|
||||
BraveAPIKey string
|
||||
BraveMaxResults int
|
||||
BraveEnabled bool
|
||||
TavilyAPIKey string
|
||||
TavilyBaseURL string
|
||||
TavilyMaxResults int
|
||||
TavilyEnabled bool
|
||||
DuckDuckGoMaxResults int
|
||||
DuckDuckGoEnabled bool
|
||||
PerplexityAPIKey string
|
||||
@@ -320,20 +467,21 @@ type WebSearchToolOptions struct {
|
||||
SearXNGBaseURL string
|
||||
SearXNGMaxResults int
|
||||
SearXNGEnabled bool
|
||||
Proxy string
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > SearXNG > DuckDuckGo
|
||||
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey}
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
@@ -342,8 +490,17 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
if opts.SearXNGMaxResults > 0 {
|
||||
maxResults = opts.SearXNGMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
proxy: opts.Proxy,
|
||||
}
|
||||
if opts.TavilyMaxResults > 0 {
|
||||
maxResults = opts.TavilyMaxResults
|
||||
}
|
||||
} else if opts.DuckDuckGoEnabled {
|
||||
provider = &DuckDuckGoSearchProvider{}
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
@@ -365,15 +522,15 @@ func (t *WebSearchTool) Description() string {
|
||||
return "Search the web for current information. Returns titles, URLs, and snippets from search results."
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *WebSearchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
"count": map[string]interface{}{
|
||||
"count": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Number of results (1-10)",
|
||||
"minimum": 1.0,
|
||||
@@ -384,7 +541,7 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("query is required")
|
||||
@@ -410,6 +567,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
|
||||
type WebFetchTool struct {
|
||||
maxChars int
|
||||
proxy string
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int) *WebFetchTool {
|
||||
@@ -421,6 +579,16 @@ func NewWebFetchTool(maxChars int) *WebFetchTool {
|
||||
}
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
}
|
||||
return &WebFetchTool{
|
||||
maxChars: maxChars,
|
||||
proxy: proxy,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Name() string {
|
||||
return "web_fetch"
|
||||
}
|
||||
@@ -429,15 +597,15 @@ func (t *WebFetchTool) Description() string {
|
||||
return "Fetch a URL and extract readable content (HTML to text). Use this to get weather info, news, articles, or any web content."
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
func (t *WebFetchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{
|
||||
"properties": map[string]any{
|
||||
"url": map[string]any{
|
||||
"type": "string",
|
||||
"description": "URL to fetch",
|
||||
},
|
||||
"maxChars": map[string]interface{}{
|
||||
"maxChars": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum characters to extract",
|
||||
"minimum": 100.0,
|
||||
@@ -447,7 +615,7 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
urlStr, ok := args["url"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("url is required")
|
||||
@@ -480,20 +648,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DisableCompression: false,
|
||||
TLSHandshakeTimeout: 15 * time.Second,
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("stopped after 5 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
client, err := createHTTPClient(t.proxy, 60*time.Second)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
|
||||
}
|
||||
|
||||
// Configure redirect handling
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("stopped after 5 redirects")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
@@ -512,7 +677,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
var text, extractor string
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
var jsonData interface{}
|
||||
var jsonData any
|
||||
if err := json.Unmarshal(body, &jsonData); err == nil {
|
||||
formatted, _ := json.MarshalIndent(jsonData, "", " ")
|
||||
text = string(formatted)
|
||||
@@ -535,7 +700,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
text = text[:maxChars]
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
result := map[string]any{
|
||||
"url": urlStr,
|
||||
"status": resp.StatusCode,
|
||||
"extractor": extractor,
|
||||
@@ -547,7 +712,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
|
||||
ForLLM: fmt.Sprintf(
|
||||
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
|
||||
len(text),
|
||||
urlStr,
|
||||
extractor,
|
||||
truncated,
|
||||
),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
+256
-12
@@ -7,6 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
||||
@@ -20,7 +21,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
@@ -56,7 +57,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
@@ -77,7 +78,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"url": "not-a-valid-url",
|
||||
}
|
||||
|
||||
@@ -98,7 +99,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"url": "ftp://example.com/file.txt",
|
||||
}
|
||||
|
||||
@@ -119,7 +120,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
args := map[string]any{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -147,7 +148,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
|
||||
tool := NewWebFetchTool(1000) // Limit to 1000 chars
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
@@ -159,7 +160,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}
|
||||
|
||||
// ForUser should contain truncated content (not the full 20000 chars)
|
||||
resultMap := make(map[string]interface{})
|
||||
resultMap := make(map[string]any)
|
||||
json.Unmarshal([]byte(result.ForUser), &resultMap)
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if len(text) > 1100 { // Allow some margin
|
||||
@@ -191,7 +192,7 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
args := map[string]any{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -206,13 +207,17 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
|
||||
w.Write(
|
||||
[]byte(
|
||||
`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`,
|
||||
),
|
||||
)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
@@ -251,7 +256,8 @@ func TestWebFetchTool_extractText(t *testing.T) {
|
||||
if len(lines) < 2 {
|
||||
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
|
||||
}
|
||||
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") {
|
||||
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") ||
|
||||
!strings.Contains(got, "Paragraph 2") {
|
||||
t.Errorf("Missing expected text: %q", got)
|
||||
}
|
||||
},
|
||||
@@ -312,7 +318,7 @@ func TestWebFetchTool_extractText(t *testing.T) {
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
args := map[string]any{
|
||||
"url": "https://",
|
||||
}
|
||||
|
||||
@@ -328,3 +334,241 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHTTPClient_ProxyConfigured(t *testing.T) {
|
||||
client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("createHTTPClient() error: %v", err)
|
||||
}
|
||||
if client.Timeout != 12*time.Second {
|
||||
t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second)
|
||||
}
|
||||
|
||||
tr, ok := client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
if tr.Proxy == nil {
|
||||
t.Fatal("transport.Proxy is nil, want non-nil")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
proxyURL, err := tr.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("transport.Proxy(req) error: %v", err)
|
||||
}
|
||||
if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHTTPClient_InvalidProxy(t *testing.T) {
|
||||
_, err := createHTTPClient("://bad-proxy", 10*time.Second)
|
||||
if err == nil {
|
||||
t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) {
|
||||
client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("createHTTPClient() error: %v", err)
|
||||
}
|
||||
|
||||
tr, ok := client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
proxyURL, err := tr.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("transport.Proxy(req) error: %v", err)
|
||||
}
|
||||
if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" {
|
||||
t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) {
|
||||
_, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second)
|
||||
if err == nil {
|
||||
t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
|
||||
t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
|
||||
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("http_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("https_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("ALL_PROXY", "")
|
||||
t.Setenv("all_proxy", "")
|
||||
t.Setenv("NO_PROXY", "")
|
||||
t.Setenv("no_proxy", "")
|
||||
|
||||
client, err := createHTTPClient("", 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("createHTTPClient() error: %v", err)
|
||||
}
|
||||
|
||||
tr, ok := client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
if tr.Proxy == nil {
|
||||
t.Fatal("transport.Proxy is nil, want proxy function from environment")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
if _, err := tr.Proxy(req); err != nil {
|
||||
t.Fatalf("transport.Proxy(req) error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
|
||||
if tool.maxChars != 1024 {
|
||||
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
|
||||
}
|
||||
if tool.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
|
||||
tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
|
||||
if tool.maxChars != 50000 {
|
||||
t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "k",
|
||||
PerplexityMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
p, ok := tool.provider.(*PerplexitySearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
|
||||
}
|
||||
if p.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("brave", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "k",
|
||||
BraveMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
p, ok := tool.provider.(*BraveSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
|
||||
}
|
||||
if p.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("duckduckgo", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: true,
|
||||
DuckDuckGoMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
|
||||
}
|
||||
if p.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestWebTool_TavilySearch_Success verifies successful Tavily search
|
||||
func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// Verify payload
|
||||
var payload map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&payload)
|
||||
if payload["api_key"] != "test-key" {
|
||||
t.Errorf("Expected api_key test-key, got %v", payload["api_key"])
|
||||
}
|
||||
if payload["query"] != "test query" {
|
||||
t.Errorf("Expected query 'test query', got %v", payload["query"])
|
||||
}
|
||||
|
||||
// Return mock response
|
||||
response := map[string]any{
|
||||
"results": []map[string]any{
|
||||
{
|
||||
"title": "Test Result 1",
|
||||
"url": "https://example.com/1",
|
||||
"content": "Content for result 1",
|
||||
},
|
||||
{
|
||||
"title": "Test Result 2",
|
||||
"url": "https://example.com/2",
|
||||
"content": "Content for result 2",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "test query",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain result titles and URLs
|
||||
if !strings.Contains(result.ForUser, "Test Result 1") ||
|
||||
!strings.Contains(result.ForUser, "https://example.com/1") {
|
||||
t.Errorf("Expected results in output, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Should mention via Tavily
|
||||
if !strings.Contains(result.ForUser, "via Tavily") {
|
||||
t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user