Merge upstream/main into fix/bugfixes

Resolve conflicts:
- provider.go: keep upstream's serializeMessages (supersedes stripSystemParts)
- provider_test.go: keep upstream's serializeMessages tests
- loop_test.go: add slices import needed by upstream tests
- shell.go: merge PR's --format deny fix with upstream's block device
  pattern, safePaths, and absolutePathPattern
- shell_test.go: include tests from both branches

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
I Putu Eddy Irawan
2026-03-03 21:55:26 +07:00
119 changed files with 8055 additions and 1855 deletions
+5 -3
View File
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"strings"
"sync"
"time"
@@ -222,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
return SilentResult("No scheduled jobs")
}
result := "Scheduled jobs:\n"
var result strings.Builder
result.WriteString("Scheduled jobs:\n")
for _, j := range jobs {
var scheduleInfo string
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
@@ -234,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
} else {
scheduleInfo = "unknown"
}
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
}
return SilentResult(result)
return SilentResult(result.String())
}
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
+11 -14
View File
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/fs"
"regexp"
"strings"
)
@@ -15,14 +16,12 @@ type EditFileTool struct {
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &EditFileTool{fs: fs}
return &EditFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *EditFileTool) Name() string {
@@ -80,14 +79,12 @@ type AppendFileTool struct {
fs fileSystem
}
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &AppendFileTool{fs: fs}
return &AppendFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *AppendFileTool) Name() string {
+67 -21
View File
@@ -6,6 +6,7 @@ import (
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
"time"
@@ -87,14 +88,12 @@ type ReadFileTool struct {
fs fileSystem
}
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
func NewReadFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ReadFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &ReadFileTool{fs: fs}
return &ReadFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *ReadFileTool) Name() string {
@@ -135,14 +134,12 @@ type WriteFileTool struct {
fs fileSystem
}
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &WriteFileTool{fs: fs}
return &WriteFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *WriteFileTool) Name() string {
@@ -192,14 +189,12 @@ type ListDirTool struct {
fs fileSystem
}
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
var fs fileSystem
if restrict {
fs = &sandboxFs{workspace: workspace}
} else {
fs = &hostFs{}
func NewListDirTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ListDirTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &ListDirTool{fs: fs}
return &ListDirTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *ListDirTool) Name() string {
@@ -394,6 +389,57 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
return entries, err
}
// whitelistFs wraps a sandboxFs and allows access to specific paths outside
// the workspace when they match any of the provided patterns.
type whitelistFs struct {
sandbox *sandboxFs
host hostFs
patterns []*regexp.Regexp
}
func (w *whitelistFs) matches(path string) bool {
for _, p := range w.patterns {
if p.MatchString(path) {
return true
}
}
return false
}
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
if w.matches(path) {
return w.host.ReadFile(path)
}
return w.sandbox.ReadFile(path)
}
func (w *whitelistFs) WriteFile(path string, data []byte) error {
if w.matches(path) {
return w.host.WriteFile(path, data)
}
return w.sandbox.WriteFile(path, data)
}
func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
if w.matches(path) {
return w.host.ReadDir(path)
}
return w.sandbox.ReadDir(path)
}
// buildFs returns the appropriate fileSystem implementation based on restriction
// settings and optional path whitelist patterns.
func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
if !restrict {
return &hostFs{}
}
sandbox := &sandboxFs{workspace: workspace}
if len(patterns) > 0 {
return &whitelistFs{sandbox: sandbox, patterns: patterns}
}
return sandbox
}
// Helper to get a safe relative path for os.Root usage
func getSafeRelPath(workspace, path string) (string, error) {
if workspace == "" {
+34
View File
@@ -5,6 +5,7 @@ import (
"io"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
@@ -486,3 +487,36 @@ func TestRootRW_Write(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, newData, content)
}
// TestWhitelistFs_AllowsMatchingPaths verifies that whitelistFs allows access to
// paths matching the whitelist patterns while blocking non-matching paths.
func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
workspace := t.TempDir()
outsideDir := t.TempDir()
outsideFile := filepath.Join(outsideDir, "allowed.txt")
os.WriteFile(outsideFile, []byte("outside content"), 0o644)
// Pattern allows access to the outsideDir.
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(outsideDir))}
tool := NewReadFileTool(workspace, true, patterns)
// Read from whitelisted path should succeed.
result := tool.Execute(context.Background(), map[string]any{"path": outsideFile})
if result.IsError {
t.Errorf("expected whitelisted path to be readable, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "outside content") {
t.Errorf("expected file content, got: %s", result.ForLLM)
}
// Read from non-whitelisted path outside workspace should fail.
otherDir := t.TempDir()
otherFile := filepath.Join(otherDir, "blocked.txt")
os.WriteFile(otherFile, []byte("blocked"), 0o644)
result = tool.Execute(context.Background(), map[string]any{"path": otherFile})
if !result.IsError {
t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
}
}
+246
View File
@@ -0,0 +1,246 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MCPManager defines the interface for MCP manager operations
// This allows for easier testing with mock implementations
type MCPManager interface {
CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error)
}
// MCPTool wraps an MCP tool to implement the Tool interface
type MCPTool struct {
manager MCPManager
serverName string
tool *mcp.Tool
}
// NewMCPTool creates a new MCP tool wrapper
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
return &MCPTool{
manager: manager,
serverName: serverName,
tool: tool,
}
}
// sanitizeIdentifierComponent normalizes a string so it can be safely used
// as part of a tool/function identifier for downstream providers.
// It:
// - lowercases the string
// - replaces any character not in [a-z0-9_-] with '_'
// - collapses multiple consecutive '_' into a single '_'
// - trims leading/trailing '_'
// - falls back to "unnamed" if the result is empty
// - truncates overly long components to a reasonable length
func sanitizeIdentifierComponent(s string) string {
const maxLen = 64
s = strings.ToLower(s)
var b strings.Builder
b.Grow(len(s))
prevUnderscore := false
for _, r := range s {
isAllowed := (r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' || r == '-'
if !isAllowed {
// Normalize any disallowed character to '_'
if !prevUnderscore {
b.WriteRune('_')
prevUnderscore = true
}
continue
}
if r == '_' {
if prevUnderscore {
continue
}
prevUnderscore = true
} else {
prevUnderscore = false
}
b.WriteRune(r)
}
result := strings.Trim(b.String(), "_")
if result == "" {
result = "unnamed"
}
if len(result) > maxLen {
result = result[:maxLen]
}
return result
}
// Name returns the tool name, prefixed with the server name.
// The total length is capped at 64 characters (OpenAI-compatible API limit).
// A short hash of the original (unsanitized) server and tool names is appended
// whenever sanitization is lossy or the name is truncated, ensuring that two
// names which differ only in disallowed characters remain distinct after sanitization.
func (t *MCPTool) Name() string {
// Prefix with server name to avoid conflicts, and sanitize components
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
strings.ToLower(t.tool.Name) == sanitizedTool
const maxTotal = 64
if lossless && len(full) <= maxTotal {
return full
}
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
// (not the sanitized names) so different originals always yield different hashes.
h := fnv.New32a()
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
base := full
if len(base) > maxTotal-9 {
base = strings.TrimRight(full[:maxTotal-9], "_")
}
return base + "_" + suffix
}
// Description returns the tool description
func (t *MCPTool) Description() string {
desc := t.tool.Description
if desc == "" {
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
}
// Add server info to description
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
}
// Parameters returns the tool parameters schema
func (t *MCPTool) Parameters() map[string]any {
// The InputSchema is already a JSON Schema object
schema := t.tool.InputSchema
// Handle nil schema
if schema == nil {
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// Try direct conversion first (fast path)
if schemaMap, ok := schema.(map[string]any); ok {
return schemaMap
}
// Handle json.RawMessage and []byte - unmarshal directly
var jsonData []byte
if rawMsg, ok := schema.(json.RawMessage); ok {
jsonData = rawMsg
} else if bytes, ok := schema.([]byte); ok {
jsonData = bytes
}
if jsonData != nil {
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err == nil {
return result
}
// Fallback on error
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// For other types (structs, etc.), convert via JSON marshal/unmarshal
var err error
jsonData, err = json.Marshal(schema)
if err != nil {
// Fallback to empty schema if marshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err != nil {
// Fallback to empty schema if unmarshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
return result
}
// Execute executes the MCP tool
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
if err != nil {
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
}
if result == nil {
nilErr := fmt.Errorf("MCP tool returned nil result without error")
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
}
// Handle error result from server
if result.IsError {
errMsg := extractContentText(result.Content)
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
}
// Extract text content from result
output := extractContentText(result.Content)
return &ToolResult{
ForLLM: output,
IsError: false,
}
}
// extractContentText extracts text from MCP content array
func extractContentText(content []mcp.Content) string {
var parts []string
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
parts = append(parts, v.Text)
case *mcp.ImageContent:
// For images, just indicate that an image was returned
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
default:
// For other content types, use string representation
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
}
}
return strings.Join(parts, "\n")
}
+492
View File
@@ -0,0 +1,492 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MockMCPManager is a mock implementation of MCPManager interface for testing
type MockMCPManager struct {
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
}
func (m *MockMCPManager) CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error) {
if m.callToolFunc != nil {
return m.callToolFunc(ctx, serverName, toolName, arguments)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "mock result"},
},
IsError: false,
}, nil
}
// TestNewMCPTool verifies MCP tool creation
func TestNewMCPTool(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: "A test tool",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"input": map[string]any{
"type": "string",
"description": "Test input",
},
},
},
}
mcpTool := NewMCPTool(manager, "test_server", tool)
if mcpTool == nil {
t.Fatal("NewMCPTool should not return nil")
}
// Verify tool properties we can access
if mcpTool.Name() != "mcp_test_server_test_tool" {
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
}
}
// TestMCPTool_Name verifies tool name with server prefix
func TestMCPTool_Name(t *testing.T) {
tests := []struct {
name string
serverName string
toolName string
expected string
}{
{
name: "simple name",
serverName: "github",
toolName: "create_issue",
expected: "mcp_github_create_issue",
},
{
name: "filesystem server",
serverName: "filesystem",
toolName: "read_file",
expected: "mcp_filesystem_read_file",
},
{
name: "remote server",
serverName: "remote-api",
toolName: "fetch_data",
expected: "mcp_remote-api_fetch_data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: tt.toolName}
mcpTool := NewMCPTool(manager, tt.serverName, tool)
result := mcpTool.Name()
if result != tt.expected {
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestMCPTool_Description verifies tool description generation
func TestMCPTool_Description(t *testing.T) {
tests := []struct {
name string
serverName string
toolDescription string
expectContains []string
}{
{
name: "with description",
serverName: "github",
toolDescription: "Create a GitHub issue",
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
},
{
name: "empty description",
serverName: "filesystem",
toolDescription: "",
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: tt.toolDescription,
}
mcpTool := NewMCPTool(manager, tt.serverName, tool)
result := mcpTool.Description()
for _, expected := range tt.expectContains {
if !strings.Contains(result, expected) {
t.Errorf("Description should contain '%s', got: %s", expected, result)
}
}
})
}
}
// TestMCPTool_Parameters verifies parameter schema conversion
func TestMCPTool_Parameters(t *testing.T) {
tests := []struct {
name string
inputSchema any
expectType string
checkProperty string
expectProperty bool
}{
{
name: "map schema",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query",
},
},
"required": []string{"query"},
},
expectType: "object",
checkProperty: "query",
expectProperty: true,
},
{
name: "nil schema",
inputSchema: nil,
expectType: "object",
expectProperty: false,
},
{
name: "json.RawMessage schema",
inputSchema: []byte(`{
"type": "object",
"properties": {
"repo": {
"type": "string",
"description": "Repository name"
},
"stars": {
"type": "integer",
"description": "Minimum stars"
}
},
"required": ["repo"]
}`),
expectType: "object",
checkProperty: "repo",
expectProperty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: tt.inputSchema,
}
mcpTool := NewMCPTool(manager, "test_server", tool)
params := mcpTool.Parameters()
if params == nil {
t.Fatal("Parameters should not be nil")
}
if params["type"] != tt.expectType {
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
}
// Check if property exists when expected
if tt.checkProperty != "" {
properties, ok := params["properties"].(map[string]any)
if !ok && tt.expectProperty {
t.Errorf("Expected properties to be a map")
return
}
if ok {
_, hasProperty := properties[tt.checkProperty]
if hasProperty != tt.expectProperty {
t.Errorf("Expected property '%s' existence: %v, got: %v",
tt.checkProperty, tt.expectProperty, hasProperty)
}
}
}
})
}
}
// TestMCPTool_Execute_Success tests successful tool execution
func TestMCPTool_Execute_Success(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
// Verify correct parameters passed
if serverName != "github" {
t.Errorf("Expected serverName 'github', got '%s'", serverName)
}
if toolName != "search_repos" {
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Found 3 repositories"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{
Name: "search_repos",
Description: "Search GitHub repositories",
}
mcpTool := NewMCPTool(manager, "github", tool)
ctx := context.Background()
args := map[string]any{
"query": "golang mcp",
}
result := mcpTool.Execute(ctx, args)
if result == nil {
t.Fatal("Result should not be nil")
}
if result.IsError {
t.Errorf("Expected no error, got error: %s", result.ForLLM)
}
if result.ForLLM != "Found 3 repositories" {
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
}
}
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
func TestMCPTool_Execute_ManagerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("connection failed")
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "connection failed") {
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_ServerError tests execution when server returns error
func TestMCPTool_Execute_ServerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Invalid API key"},
},
IsError: true,
}, nil
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Invalid API key") {
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "First line"},
&mcp.TextContent{Text: "Second line"},
&mcp.TextContent{Text: "Third line"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{Name: "multi_output"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result.IsError {
t.Errorf("Expected no error, got: %s", result.ForLLM)
}
expected := "First line\nSecond line\nThird line"
if result.ForLLM != expected {
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
}
}
// TestExtractContentText_TextContent tests text content extraction
func TestExtractContentText_TextContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Hello World"},
&mcp.TextContent{Text: "Second message"},
}
result := extractContentText(content)
expected := "Hello World\nSecond message"
if result != expected {
t.Errorf("Expected '%s', got '%s'", expected, result)
}
}
// TestExtractContentText_ImageContent tests image content extraction
func TestExtractContentText_ImageContent(t *testing.T) {
content := []mcp.Content{
&mcp.ImageContent{
Data: []byte("base64data"),
MIMEType: "image/png",
},
}
result := extractContentText(content)
if !strings.Contains(result, "[Image:") {
t.Errorf("Expected image indicator, got: %s", result)
}
if !strings.Contains(result, "image/png") {
t.Errorf("Expected MIME type in output, got: %s", result)
}
}
// TestExtractContentText_MixedContent tests mixed content types
func TestExtractContentText_MixedContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Description"},
&mcp.ImageContent{
Data: []byte("data"),
MIMEType: "image/jpeg",
},
&mcp.TextContent{Text: "More text"},
}
result := extractContentText(content)
if !strings.Contains(result, "Description") {
t.Errorf("Should contain text content, got: %s", result)
}
if !strings.Contains(result, "[Image:") {
t.Errorf("Should contain image indicator, got: %s", result)
}
if !strings.Contains(result, "More text") {
t.Errorf("Should contain second text, got: %s", result)
}
}
// TestExtractContentText_EmptyContent tests empty content array
func TestExtractContentText_EmptyContent(t *testing.T) {
content := []mcp.Content{}
result := extractContentText(content)
if result != "" {
t.Errorf("Expected empty string for empty content, got: %s", result)
}
}
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
func TestMCPTool_InterfaceCompliance(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: "test"}
mcpTool := NewMCPTool(manager, "test_server", tool)
// Verify it implements Tool interface
var _ Tool = mcpTool
}
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
manager := &MockMCPManager{}
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"description": "The name parameter",
},
},
"required": []string{"name"},
}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: schema,
}
mcpTool := NewMCPTool(manager, "test_server", tool)
params := mcpTool.Parameters()
// Should return the schema as-is when it's already a map
if params["type"] != "object" {
t.Errorf("Expected type 'object', got '%v'", params["type"])
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Error("Properties should be a map")
}
nameParam, ok := props["name"].(map[string]any)
if !ok {
t.Error("Name parameter should exist")
}
if nameParam["type"] != "string" {
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
}
}
+6 -1
View File
@@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry {
func (r *ToolRegistry) Register(tool Tool) {
r.mu.Lock()
defer r.mu.Unlock()
r.tools[tool.Name()] = tool
name := tool.Name()
if _, exists := r.tools[name]; exists {
logger.WarnCF("tools", "Tool registration overwrites existing tool",
map[string]any{"name": name})
}
r.tools[name] = tool
}
func (r *ToolRegistry) Get(name string) (Tool, bool) {
+1 -1
View File
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
r := NewToolRegistry()
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
for i := range 50 {
wg.Add(1)
go func(n int) {
defer wg.Done()
+96 -50
View File
@@ -21,53 +21,77 @@ type ExecTool struct {
timeout time.Duration
denyPatterns []*regexp.Regexp
allowPatterns []*regexp.Regexp
customAllowPatterns []*regexp.Regexp
restrictToWorkspace bool
}
var defaultDenyPatterns = []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`), // Match disk wiping commands, avoid matching --format flags
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
regexp.MustCompile(`\$\([^)]+\)`),
regexp.MustCompile(`\$\{[^}]+\}`),
regexp.MustCompile("`[^`]+`"),
regexp.MustCompile(`\|\s*sh\b`),
regexp.MustCompile(`\|\s*bash\b`),
regexp.MustCompile(`;\s*rm\s+-[rf]`),
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
regexp.MustCompile(`<<\s*EOF`),
regexp.MustCompile(`\$\(\s*cat\s+`),
regexp.MustCompile(`\$\(\s*curl\s+`),
regexp.MustCompile(`\$\(\s*wget\s+`),
regexp.MustCompile(`\$\(\s*which\s+`),
regexp.MustCompile(`\bsudo\b`),
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
regexp.MustCompile(`\byum\s+(install|remove)\b`),
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
regexp.MustCompile(`\bdocker\s+run\b`),
regexp.MustCompile(`\bdocker\s+exec\b`),
regexp.MustCompile(`\bgit\s+push\b`),
regexp.MustCompile(`\bgit\s+force\b`),
regexp.MustCompile(`\bssh\b.*@`),
regexp.MustCompile(`\beval\b`),
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
}
var (
defaultDenyPatterns = []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
// Match disk wiping commands, avoid matching --format flags
regexp.MustCompile(
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
),
regexp.MustCompile(`\bdd\s+if=`),
// Block writes to block devices (all common naming schemes).
regexp.MustCompile(
`>\s*/dev/(sd[a-z]|hd[a-z]|vd[a-z]|xvd[a-z]|nvme\d|mmcblk\d|loop\d|dm-\d|md\d|sr\d|nbd\d)`,
),
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
regexp.MustCompile(`\$\([^)]+\)`),
regexp.MustCompile(`\$\{[^}]+\}`),
regexp.MustCompile("`[^`]+`"),
regexp.MustCompile(`\|\s*sh\b`),
regexp.MustCompile(`\|\s*bash\b`),
regexp.MustCompile(`;\s*rm\s+-[rf]`),
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
regexp.MustCompile(`<<\s*EOF`),
regexp.MustCompile(`\$\(\s*cat\s+`),
regexp.MustCompile(`\$\(\s*curl\s+`),
regexp.MustCompile(`\$\(\s*wget\s+`),
regexp.MustCompile(`\$\(\s*which\s+`),
regexp.MustCompile(`\bsudo\b`),
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
regexp.MustCompile(`\byum\s+(install|remove)\b`),
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
regexp.MustCompile(`\bdocker\s+run\b`),
regexp.MustCompile(`\bdocker\s+exec\b`),
regexp.MustCompile(`\bgit\s+push\b`),
regexp.MustCompile(`\bgit\s+force\b`),
regexp.MustCompile(`\bssh\b.*@`),
regexp.MustCompile(`\beval\b`),
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
}
// absolutePathPattern matches absolute file paths in commands (Unix and Windows).
absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
// safePaths are kernel pseudo-devices that are always safe to reference in
// commands, regardless of workspace restriction. They contain no user data
// and cannot cause destructive writes.
safePaths = map[string]bool{
"/dev/null": true,
"/dev/zero": true,
"/dev/random": true,
"/dev/urandom": true,
"/dev/stdin": true,
"/dev/stdout": true,
"/dev/stderr": true,
}
)
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil)
@@ -75,6 +99,7 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
customAllowPatterns := make([]*regexp.Regexp, 0)
if config != nil {
execConfig := config.Tools.Exec
@@ -95,6 +120,13 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
}
for _, pattern := range execConfig.CustomAllowPatterns {
re, err := regexp.Compile(pattern)
if err != nil {
return nil, fmt.Errorf("invalid custom allow pattern %q: %w", pattern, err)
}
customAllowPatterns = append(customAllowPatterns, re)
}
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
@@ -104,6 +136,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
timeout: 60 * time.Second,
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
restrictToWorkspace: restrict,
}, nil
}
@@ -258,9 +291,20 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
cmd := strings.TrimSpace(command)
lower := strings.ToLower(cmd)
for _, pattern := range t.denyPatterns {
// Custom allow patterns exempt a command from deny checks.
explicitlyAllowed := false
for _, pattern := range t.customAllowPatterns {
if pattern.MatchString(lower) {
return "Command blocked by safety guard (dangerous pattern detected)"
explicitlyAllowed = true
break
}
}
if !explicitlyAllowed {
for _, pattern := range t.denyPatterns {
if pattern.MatchString(lower) {
return "Command blocked by safety guard (dangerous pattern detected)"
}
}
}
@@ -287,16 +331,18 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
return ""
}
pathPattern := regexp.MustCompile(`(?:^|\s|=)([A-Za-z]:\\[^\\"']+|/[a-zA-Z.][^\s"']*)`)
matches := pathPattern.FindAllStringSubmatch(cmd, -1)
matches := absolutePathPattern.FindAllString(cmd, -1)
for _, match := range matches {
raw := match[1]
for _, raw := range matches {
p, err := filepath.Abs(raw)
if err != nil {
continue
}
if safePaths[p] {
continue
}
rel, err := filepath.Rel(cwdPath, p)
if err != nil {
continue
+106 -25
View File
@@ -7,6 +7,8 @@ import (
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
// TestShellTool_Success verifies successful command execution
@@ -310,6 +312,60 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
}
}
// TestShellTool_DevNullAllowed verifies that /dev/null redirections are not blocked (issue #964).
func TestShellTool_DevNullAllowed(t *testing.T) {
tmpDir := t.TempDir()
tool, err := NewExecTool(tmpDir, true)
if err != nil {
t.Fatalf("unable to configure exec tool: %s", err)
}
commands := []string{
"echo hello 2>/dev/null",
"echo hello >/dev/null",
"echo hello > /dev/null",
"echo hello 2> /dev/null",
"echo hello >/dev/null 2>&1",
"find " + tmpDir + " -name '*.go' 2>/dev/null",
}
for _, cmd := range commands {
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
}
}
}
// TestShellTool_BlockDevices verifies that writes to block devices are blocked (issue #965).
func TestShellTool_BlockDevices(t *testing.T) {
tool, err := NewExecTool("", false)
if err != nil {
t.Fatalf("unable to configure exec tool: %s", err)
}
blocked := []string{
"echo x > /dev/sda",
"echo x > /dev/hda",
"echo x > /dev/vda",
"echo x > /dev/xvda",
"echo x > /dev/nvme0n1",
"echo x > /dev/mmcblk0",
"echo x > /dev/loop0",
"echo x > /dev/dm-0",
"echo x > /dev/md0",
"echo x > /dev/sr0",
"echo x > /dev/nbd0",
}
for _, cmd := range blocked {
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
if !result.IsError {
t.Errorf("expected block device write to be blocked: %s", cmd)
}
}
}
// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping
// commands (format, mkfs, diskpart) blocks them when preceded by shell separators
// but does NOT block legitimate uses like --format flags.
@@ -322,7 +378,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
ctx := context.Background()
// These should be BLOCKED (disk wiping commands)
blocked := []struct {
blockedCmds := []struct {
name string
cmd string
}{
@@ -334,7 +390,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
{"diskpart standalone", "diskpart /s script.txt"},
}
for _, tt := range blocked {
for _, tt := range blockedCmds {
t.Run("blocked_"+tt.name, func(t *testing.T) {
result := tool.Execute(ctx, map[string]any{"command": tt.cmd})
if !result.IsError {
@@ -362,35 +418,60 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
}
}
// TestShellTool_RestrictToWorkspace_HiddenDirs verifies that hidden directory
// paths (starting with .) are properly detected by the workspace guard.
func TestShellTool_RestrictToWorkspace_HiddenDirs(t *testing.T) {
// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
// are allowed even when workspace restriction is active.
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
tmpDir := t.TempDir()
tool, err := NewExecTool(tmpDir, false)
tool, err := NewExecTool(tmpDir, true)
if err != nil {
t.Fatalf("unable to configure exec tool: %s", err)
}
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
// Reading a hidden dir outside workspace should be blocked
result := tool.Execute(ctx, map[string]any{
"command": "cat /.ssh/config",
})
if !result.IsError {
t.Errorf("Expected /.ssh/config to be blocked with restrictToWorkspace=true")
// These reference paths outside workspace but should be allowed via safePaths.
commands := []string{
"cat /dev/urandom | head -c 16 | od",
"echo test > /dev/null",
"dd if=/dev/zero bs=1 count=1",
}
// Flag-attached paths outside workspace should be blocked
result2 := tool.Execute(ctx, map[string]any{
"command": "grep --include=/etc/passwd pattern",
})
if !result2.IsError {
// This tests the = delimiter fix; --include=/etc/passwd uses = in real
// usage but --include /etc/passwd uses space. Both patterns should catch it.
// If this specific form isn't blocked, it's acceptable since the primary
// concern is the = form (--file=/etc/passwd).
_ = result2 // acceptable either way for this pattern variant
for _, cmd := range commands {
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
}
}
}
// TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt
// commands from deny pattern checks.
func TestShellTool_CustomAllowPatterns(t *testing.T) {
cfg := &config.Config{
Tools: config.ToolsConfig{
Exec: config.ExecConfig{
EnableDenyPatterns: true,
CustomAllowPatterns: []string{`\bgit\s+push\s+origin\b`},
},
},
}
tool, err := NewExecToolWithConfig("", false, cfg)
if err != nil {
t.Fatalf("unable to configure exec tool: %s", err)
}
// "git push origin main" should be allowed by custom allow pattern.
result := tool.Execute(context.Background(), map[string]any{
"command": "git push origin main",
})
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
}
// "git push upstream main" should still be blocked (does not match allow pattern).
result = tool.Execute(context.Background(), map[string]any{
"command": "git push upstream main",
})
if !result.IsError {
t.Errorf("'git push upstream main' should still be blocked by deny pattern")
}
}
+86 -61
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -15,6 +16,14 @@ import (
const (
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
// HTTP client timeouts for web tool providers.
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
fetchTimeout = 60 * time.Second // WebFetchTool
defaultMaxChars = 50000
maxRedirects = 5
)
// Pre-compiled regexes for HTML text extraction
@@ -74,6 +83,7 @@ type SearchProvider interface {
type BraveSearchProvider struct {
apiKey string
proxy string
client *http.Client
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -88,11 +98,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -103,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
Web struct {
Results []struct {
@@ -143,6 +153,7 @@ type TavilySearchProvider struct {
apiKey string
baseURL string
proxy string
client *http.Client
}
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -174,11 +185,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -226,7 +233,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
}
type DuckDuckGoSearchProvider struct {
proxy string
proxy string
client *http.Client
}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -239,11 +247,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -285,7 +289,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
maxItems := min(len(matches), count)
for i := 0; i < maxItems; i++ {
for i := range maxItems {
urlStr := matches[i][1]
title := stripTags(matches[i][2])
title = strings.TrimSpace(title)
@@ -293,9 +297,9 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
// URL decoding if needed
if strings.Contains(urlStr, "uddg=") {
if u, err := url.QueryUnescape(urlStr); err == nil {
idx := strings.Index(u, "uddg=")
if idx != -1 {
urlStr = u[idx+5:]
_, after, ok := strings.Cut(u, "uddg=")
if ok {
urlStr = after
}
}
}
@@ -322,6 +326,7 @@ func stripTags(content string) string {
type PerplexitySearchProvider struct {
apiKey string
proxy string
client *http.Client
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -356,11 +361,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 30*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -415,43 +416,60 @@ type WebSearchToolOptions struct {
Proxy string
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
}
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
}
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
}
provider = &TavilySearchProvider{
apiKey: opts.TavilyAPIKey,
baseURL: opts.TavilyBaseURL,
proxy: opts.Proxy,
client: client,
}
if opts.TavilyMaxResults > 0 {
maxResults = opts.TavilyMaxResults
}
} else if opts.DuckDuckGoEnabled {
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
}
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else {
return nil
return nil, nil
}
return &WebSearchTool{
provider: provider,
maxResults: maxResults,
}
}, nil
}
func (t *WebSearchTool) Name() string {
@@ -506,27 +524,40 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR
}
type WebFetchTool struct {
maxChars int
proxy string
maxChars int
proxy string
client *http.Client
fetchLimitBytes int64
}
func NewWebFetchTool(maxChars int) *WebFetchTool {
if maxChars <= 0 {
maxChars = 50000
}
return &WebFetchTool{
maxChars: maxChars,
}
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
// createHTTPClient cannot fail with an empty proxy string.
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
}
func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = 50000
maxChars = defaultMaxChars
}
client, err := createHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
return nil
}
if fetchLimitBytes <= 0 {
fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
}
return &WebFetchTool{
maxChars: maxChars,
proxy: proxy,
}
maxChars: maxChars,
proxy: proxy,
client: client,
fetchLimitBytes: fetchLimitBytes,
}, nil
}
func (t *WebFetchTool) Name() string {
@@ -588,27 +619,21 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(t.proxy, 60*time.Second)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
}
// Configure redirect handling
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= 5 {
return fmt.Errorf("stopped after 5 redirects")
}
return nil
}
resp, err := client.Do(req)
resp, err := t.client.Do(req)
if err != nil {
return ErrorResult(fmt.Sprintf("request failed: %v", err))
}
resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
return ErrorResult(fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes))
}
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
}
@@ -652,14 +677,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
resultJSON, _ := json.MarshalIndent(result, "", " ")
return &ToolResult{
ForLLM: fmt.Sprintf(
ForLLM: string(resultJSON),
ForUser: fmt.Sprintf(
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
len(text),
urlStr,
extractor,
truncated,
),
ForUser: string(resultJSON),
}
}
+144 -35
View File
@@ -1,15 +1,21 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
const testFetchLimit = int64(10 * 1024 * 1024)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -19,7 +25,11 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
}))
defer server.Close()
tool := NewWebFetchTool(50000)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -32,14 +42,14 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain the fetched content
if !strings.Contains(result.ForUser, "Test Page") {
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
// ForLLM should contain the fetched content (full JSON result)
if !strings.Contains(result.ForLLM, "Test Page") {
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
}
// ForLLM should contain summary
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
// ForUser should contain summary
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
}
}
@@ -55,7 +65,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
}))
defer server.Close()
tool := NewWebFetchTool(50000)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -68,15 +82,19 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain formatted JSON
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
// ForLLM should contain formatted JSON
if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{
"url": "not-a-valid-url",
@@ -97,7 +115,11 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{
"url": "ftp://example.com/file.txt",
@@ -118,7 +140,11 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{}
@@ -146,7 +172,11 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}))
defer server.Close()
tool := NewWebFetchTool(1000) // Limit to 1000 chars
tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -159,9 +189,9 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain truncated content (not the full 20000 chars)
// ForLLM should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]any)
json.Unmarshal([]byte(result.ForUser), &resultMap)
json.Unmarshal([]byte(result.ForLLM), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
@@ -174,15 +204,64 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
}
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
// Create a mock HTTP server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
// Generate a payload intentionally larger than our limit.
// Limit: 10 * 1024 * 1024 (10MB). We generate 10MB + 100 bytes of the letter 'A'.
largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100)
w.Write(largeData)
}))
// Ensure the server is shut down at the end of the test
defer ts.Close()
// Initialize the tool
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
// Prepare the arguments pointing to the URL of our local mock server
args := map[string]any{
"url": ts.URL,
}
// Execute the tool
ctx := context.Background()
result := tool.Execute(ctx, args)
// Assuming ErrorResult sets the ForLLM field with the error text.
if result == nil {
t.Fatal("expected a ToolResult, got nil")
}
// Search for the exact error string we set earlier in the Execute method
expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
}
}
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if tool != nil {
t.Errorf("Expected nil tool when Brave API key is empty")
}
// Also nil when nothing is enabled
tool = NewWebSearchTool(WebSearchToolOptions{})
tool, err = NewWebSearchTool(WebSearchToolOptions{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if tool != nil {
t.Errorf("Expected nil tool when no provider is enabled")
}
@@ -190,7 +269,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx := context.Background()
args := map[string]any{}
@@ -215,7 +297,11 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
}))
defer server.Close()
tool := NewWebFetchTool(50000)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -228,14 +314,14 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain extracted text (without script/style tags)
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
// ForLLM should contain extracted text (without script/style tags)
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
}
// Should NOT contain script or style tags
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
// Should NOT contain script or style tags in ForLLM
if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
}
}
@@ -316,7 +402,11 @@ func TestWebFetchTool_extractText(t *testing.T) {
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{
"url": "https://",
@@ -438,15 +528,22 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
}
func TestNewWebFetchToolWithProxy(t *testing.T) {
tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
if tool.maxChars != 1024 {
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else if tool.maxChars != 1024 {
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
}
if tool.proxy != "http://127.0.0.1:7890" {
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
}
tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
if tool.maxChars != 50000 {
t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
}
@@ -454,12 +551,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
PerplexityEnabled: true,
PerplexityAPIKey: "k",
PerplexityMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*PerplexitySearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
@@ -470,12 +570,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
})
t.Run("brave", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKey: "k",
BraveMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*BraveSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
@@ -486,11 +589,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
})
t.Run("duckduckgo", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: true,
DuckDuckGoMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
@@ -542,12 +648,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
}))
defer server.Close()
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
TavilyEnabled: true,
TavilyAPIKey: "test-key",
TavilyBaseURL: server.URL,
TavilyMaxResults: 5,
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
ctx := context.Background()
args := map[string]any{