mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main into fix/bugfixes
Resolve conflicts: - provider.go: keep upstream's serializeMessages (supersedes stripSystemParts) - provider_test.go: keep upstream's serializeMessages tests - loop_test.go: add slices import needed by upstream tests - shell.go: merge PR's --format deny fix with upstream's block device pattern, safePaths, and absolutePathPattern - shell_test.go: include tests from both branches Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+5
-3
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -222,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
return SilentResult("No scheduled jobs")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
var result strings.Builder
|
||||
result.WriteString("Scheduled jobs:\n")
|
||||
for _, j := range jobs {
|
||||
var scheduleInfo string
|
||||
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||
@@ -234,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
} else {
|
||||
scheduleInfo = "unknown"
|
||||
}
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
|
||||
}
|
||||
|
||||
return SilentResult(result)
|
||||
return SilentResult(result.String())
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
|
||||
|
||||
+11
-14
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -15,14 +16,12 @@ type EditFileTool struct {
|
||||
}
|
||||
|
||||
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
|
||||
func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &EditFileTool{fs: fs}
|
||||
return &EditFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Name() string {
|
||||
@@ -80,14 +79,12 @@ type AppendFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &AppendFileTool{fs: fs}
|
||||
return &AppendFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Name() string {
|
||||
|
||||
+67
-21
@@ -6,6 +6,7 @@ import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -87,14 +88,12 @@ type ReadFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewReadFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ReadFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &ReadFileTool{fs: fs}
|
||||
return &ReadFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Name() string {
|
||||
@@ -135,14 +134,12 @@ type WriteFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &WriteFileTool{fs: fs}
|
||||
return &WriteFileTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Name() string {
|
||||
@@ -192,14 +189,12 @@ type ListDirTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
|
||||
var fs fileSystem
|
||||
if restrict {
|
||||
fs = &sandboxFs{workspace: workspace}
|
||||
} else {
|
||||
fs = &hostFs{}
|
||||
func NewListDirTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ListDirTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &ListDirTool{fs: fs}
|
||||
return &ListDirTool{fs: buildFs(workspace, restrict, patterns)}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Name() string {
|
||||
@@ -394,6 +389,57 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
return entries, err
|
||||
}
|
||||
|
||||
// whitelistFs wraps a sandboxFs and allows access to specific paths outside
|
||||
// the workspace when they match any of the provided patterns.
|
||||
type whitelistFs struct {
|
||||
sandbox *sandboxFs
|
||||
host hostFs
|
||||
patterns []*regexp.Regexp
|
||||
}
|
||||
|
||||
func (w *whitelistFs) matches(path string) bool {
|
||||
for _, p := range w.patterns {
|
||||
if p.MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
|
||||
if w.matches(path) {
|
||||
return w.host.ReadFile(path)
|
||||
}
|
||||
return w.sandbox.ReadFile(path)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) WriteFile(path string, data []byte) error {
|
||||
if w.matches(path) {
|
||||
return w.host.WriteFile(path, data)
|
||||
}
|
||||
return w.sandbox.WriteFile(path, data)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
if w.matches(path) {
|
||||
return w.host.ReadDir(path)
|
||||
}
|
||||
return w.sandbox.ReadDir(path)
|
||||
}
|
||||
|
||||
// buildFs returns the appropriate fileSystem implementation based on restriction
|
||||
// settings and optional path whitelist patterns.
|
||||
func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
|
||||
if !restrict {
|
||||
return &hostFs{}
|
||||
}
|
||||
sandbox := &sandboxFs{workspace: workspace}
|
||||
if len(patterns) > 0 {
|
||||
return &whitelistFs{sandbox: sandbox, patterns: patterns}
|
||||
}
|
||||
return sandbox
|
||||
}
|
||||
|
||||
// Helper to get a safe relative path for os.Root usage
|
||||
func getSafeRelPath(workspace, path string) (string, error) {
|
||||
if workspace == "" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -486,3 +487,36 @@ func TestRootRW_Write(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, newData, content)
|
||||
}
|
||||
|
||||
// TestWhitelistFs_AllowsMatchingPaths verifies that whitelistFs allows access to
|
||||
// paths matching the whitelist patterns while blocking non-matching paths.
|
||||
func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
outsideDir := t.TempDir()
|
||||
outsideFile := filepath.Join(outsideDir, "allowed.txt")
|
||||
os.WriteFile(outsideFile, []byte("outside content"), 0o644)
|
||||
|
||||
// Pattern allows access to the outsideDir.
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(outsideDir))}
|
||||
|
||||
tool := NewReadFileTool(workspace, true, patterns)
|
||||
|
||||
// Read from whitelisted path should succeed.
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": outsideFile})
|
||||
if result.IsError {
|
||||
t.Errorf("expected whitelisted path to be readable, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "outside content") {
|
||||
t.Errorf("expected file content, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Read from non-whitelisted path outside workspace should fail.
|
||||
otherDir := t.TempDir()
|
||||
otherFile := filepath.Join(otherDir, "blocked.txt")
|
||||
os.WriteFile(otherFile, []byte("blocked"), 0o644)
|
||||
|
||||
result = tool.Execute(context.Background(), map[string]any{"path": otherFile})
|
||||
if !result.IsError {
|
||||
t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MCPManager defines the interface for MCP manager operations
|
||||
// This allows for easier testing with mock implementations
|
||||
type MCPManager interface {
|
||||
CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
// MCPTool wraps an MCP tool to implement the Tool interface
|
||||
type MCPTool struct {
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeIdentifierComponent normalizes a string so it can be safely used
|
||||
// as part of a tool/function identifier for downstream providers.
|
||||
// It:
|
||||
// - lowercases the string
|
||||
// - replaces any character not in [a-z0-9_-] with '_'
|
||||
// - collapses multiple consecutive '_' into a single '_'
|
||||
// - trims leading/trailing '_'
|
||||
// - falls back to "unnamed" if the result is empty
|
||||
// - truncates overly long components to a reasonable length
|
||||
func sanitizeIdentifierComponent(s string) string {
|
||||
const maxLen = 64
|
||||
|
||||
s = strings.ToLower(s)
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
prevUnderscore := false
|
||||
for _, r := range s {
|
||||
isAllowed := (r >= 'a' && r <= 'z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' || r == '-'
|
||||
|
||||
if !isAllowed {
|
||||
// Normalize any disallowed character to '_'
|
||||
if !prevUnderscore {
|
||||
b.WriteRune('_')
|
||||
prevUnderscore = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if r == '_' {
|
||||
if prevUnderscore {
|
||||
continue
|
||||
}
|
||||
prevUnderscore = true
|
||||
} else {
|
||||
prevUnderscore = false
|
||||
}
|
||||
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
result := strings.Trim(b.String(), "_")
|
||||
if result == "" {
|
||||
result = "unnamed"
|
||||
}
|
||||
|
||||
if len(result) > maxLen {
|
||||
result = result[:maxLen]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Name returns the tool name, prefixed with the server name.
|
||||
// The total length is capped at 64 characters (OpenAI-compatible API limit).
|
||||
// A short hash of the original (unsanitized) server and tool names is appended
|
||||
// whenever sanitization is lossy or the name is truncated, ensuring that two
|
||||
// names which differ only in disallowed characters remain distinct after sanitization.
|
||||
func (t *MCPTool) Name() string {
|
||||
// Prefix with server name to avoid conflicts, and sanitize components
|
||||
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
|
||||
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
|
||||
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
|
||||
|
||||
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
|
||||
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
|
||||
strings.ToLower(t.tool.Name) == sanitizedTool
|
||||
|
||||
const maxTotal = 64
|
||||
if lossless && len(full) <= maxTotal {
|
||||
return full
|
||||
}
|
||||
|
||||
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
|
||||
// (not the sanitized names) so different originals always yield different hashes.
|
||||
h := fnv.New32a()
|
||||
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
|
||||
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
|
||||
|
||||
base := full
|
||||
if len(base) > maxTotal-9 {
|
||||
base = strings.TrimRight(full[:maxTotal-9], "_")
|
||||
}
|
||||
return base + "_" + suffix
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *MCPTool) Description() string {
|
||||
desc := t.tool.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
|
||||
}
|
||||
// Add server info to description
|
||||
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *MCPTool) Parameters() map[string]any {
|
||||
// The InputSchema is already a JSON Schema object
|
||||
schema := t.tool.InputSchema
|
||||
|
||||
// Handle nil schema
|
||||
if schema == nil {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// Try direct conversion first (fast path)
|
||||
if schemaMap, ok := schema.(map[string]any); ok {
|
||||
return schemaMap
|
||||
}
|
||||
|
||||
// Handle json.RawMessage and []byte - unmarshal directly
|
||||
var jsonData []byte
|
||||
if rawMsg, ok := schema.(json.RawMessage); ok {
|
||||
jsonData = rawMsg
|
||||
} else if bytes, ok := schema.([]byte); ok {
|
||||
jsonData = bytes
|
||||
}
|
||||
|
||||
if jsonData != nil {
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err == nil {
|
||||
return result
|
||||
}
|
||||
// Fallback on error
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// For other types (structs, etc.), convert via JSON marshal/unmarshal
|
||||
var err error
|
||||
jsonData, err = json.Marshal(schema)
|
||||
if err != nil {
|
||||
// Fallback to empty schema if marshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err != nil {
|
||||
// Fallback to empty schema if unmarshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Execute executes the MCP tool
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
nilErr := fmt.Errorf("MCP tool returned nil result without error")
|
||||
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
|
||||
}
|
||||
|
||||
// Handle error result from server
|
||||
if result.IsError {
|
||||
errMsg := extractContentText(result.Content)
|
||||
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
|
||||
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
|
||||
}
|
||||
|
||||
// Extract text content from result
|
||||
output := extractContentText(result.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
// extractContentText extracts text from MCP content array
|
||||
func extractContentText(content []mcp.Content) string {
|
||||
var parts []string
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
parts = append(parts, v.Text)
|
||||
case *mcp.ImageContent:
|
||||
// For images, just indicate that an image was returned
|
||||
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
|
||||
default:
|
||||
// For other content types, use string representation
|
||||
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
@@ -0,0 +1,492 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MockMCPManager is a mock implementation of MCPManager interface for testing
|
||||
type MockMCPManager struct {
|
||||
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
func (m *MockMCPManager) CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
if m.callToolFunc != nil {
|
||||
return m.callToolFunc(ctx, serverName, toolName, arguments)
|
||||
}
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "mock result"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestNewMCPTool verifies MCP tool creation
|
||||
func TestNewMCPTool(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Test input",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
if mcpTool == nil {
|
||||
t.Fatal("NewMCPTool should not return nil")
|
||||
}
|
||||
// Verify tool properties we can access
|
||||
if mcpTool.Name() != "mcp_test_server_test_tool" {
|
||||
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Name verifies tool name with server prefix
|
||||
func TestMCPTool_Name(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple name",
|
||||
serverName: "github",
|
||||
toolName: "create_issue",
|
||||
expected: "mcp_github_create_issue",
|
||||
},
|
||||
{
|
||||
name: "filesystem server",
|
||||
serverName: "filesystem",
|
||||
toolName: "read_file",
|
||||
expected: "mcp_filesystem_read_file",
|
||||
},
|
||||
{
|
||||
name: "remote server",
|
||||
serverName: "remote-api",
|
||||
toolName: "fetch_data",
|
||||
expected: "mcp_remote-api_fetch_data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: tt.toolName}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Description verifies tool description generation
|
||||
func TestMCPTool_Description(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolDescription string
|
||||
expectContains []string
|
||||
}{
|
||||
{
|
||||
name: "with description",
|
||||
serverName: "github",
|
||||
toolDescription: "Create a GitHub issue",
|
||||
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
|
||||
},
|
||||
{
|
||||
name: "empty description",
|
||||
serverName: "filesystem",
|
||||
toolDescription: "",
|
||||
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: tt.toolDescription,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Description()
|
||||
|
||||
for _, expected := range tt.expectContains {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Description should contain '%s', got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters verifies parameter schema conversion
|
||||
func TestMCPTool_Parameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputSchema any
|
||||
expectType string
|
||||
checkProperty string
|
||||
expectProperty bool
|
||||
}{
|
||||
{
|
||||
name: "map schema",
|
||||
inputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
expectType: "object",
|
||||
checkProperty: "query",
|
||||
expectProperty: true,
|
||||
},
|
||||
{
|
||||
name: "nil schema",
|
||||
inputSchema: nil,
|
||||
expectType: "object",
|
||||
expectProperty: false,
|
||||
},
|
||||
{
|
||||
name: "json.RawMessage schema",
|
||||
inputSchema: []byte(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"stars": {
|
||||
"type": "integer",
|
||||
"description": "Minimum stars"
|
||||
}
|
||||
},
|
||||
"required": ["repo"]
|
||||
}`),
|
||||
expectType: "object",
|
||||
checkProperty: "repo",
|
||||
expectProperty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: tt.inputSchema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
if params == nil {
|
||||
t.Fatal("Parameters should not be nil")
|
||||
}
|
||||
|
||||
if params["type"] != tt.expectType {
|
||||
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
|
||||
}
|
||||
|
||||
// Check if property exists when expected
|
||||
if tt.checkProperty != "" {
|
||||
properties, ok := params["properties"].(map[string]any)
|
||||
if !ok && tt.expectProperty {
|
||||
t.Errorf("Expected properties to be a map")
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
_, hasProperty := properties[tt.checkProperty]
|
||||
if hasProperty != tt.expectProperty {
|
||||
t.Errorf("Expected property '%s' existence: %v, got: %v",
|
||||
tt.checkProperty, tt.expectProperty, hasProperty)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_Success tests successful tool execution
|
||||
func TestMCPTool_Execute_Success(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
// Verify correct parameters passed
|
||||
if serverName != "github" {
|
||||
t.Errorf("Expected serverName 'github', got '%s'", serverName)
|
||||
}
|
||||
if toolName != "search_repos" {
|
||||
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Found 3 repositories"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "search_repos",
|
||||
Description: "Search GitHub repositories",
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "github", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "golang mcp",
|
||||
}
|
||||
|
||||
result := mcpTool.Execute(ctx, args)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "Found 3 repositories" {
|
||||
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
|
||||
func TestMCPTool_Execute_ManagerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return nil, fmt.Errorf("connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
|
||||
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "connection failed") {
|
||||
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ServerError tests execution when server returns error
|
||||
func TestMCPTool_Execute_ServerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Invalid API key"},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
|
||||
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Invalid API key") {
|
||||
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
|
||||
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "First line"},
|
||||
&mcp.TextContent{Text: "Second line"},
|
||||
&mcp.TextContent{Text: "Third line"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "multi_output"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
expected := "First line\nSecond line\nThird line"
|
||||
if result.ForLLM != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_TextContent tests text content extraction
|
||||
func TestExtractContentText_TextContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Hello World"},
|
||||
&mcp.TextContent{Text: "Second message"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
expected := "Hello World\nSecond message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_ImageContent tests image content extraction
|
||||
func TestExtractContentText_ImageContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("base64data"),
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Expected image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "image/png") {
|
||||
t.Errorf("Expected MIME type in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_MixedContent tests mixed content types
|
||||
func TestExtractContentText_MixedContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Description"},
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("data"),
|
||||
MIMEType: "image/jpeg",
|
||||
},
|
||||
&mcp.TextContent{Text: "More text"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "Description") {
|
||||
t.Errorf("Should contain text content, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Should contain image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "More text") {
|
||||
t.Errorf("Should contain second text, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_EmptyContent tests empty content array
|
||||
func TestExtractContentText_EmptyContent(t *testing.T) {
|
||||
content := []mcp.Content{}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty string for empty content, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
|
||||
func TestMCPTool_InterfaceCompliance(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: "test"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
// Verify it implements Tool interface
|
||||
var _ Tool = mcpTool
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
|
||||
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The name parameter",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: schema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
// Should return the schema as-is when it's already a map
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got '%v'", params["type"])
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Properties should be a map")
|
||||
}
|
||||
|
||||
nameParam, ok := props["name"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Name parameter should exist")
|
||||
}
|
||||
|
||||
if nameParam["type"] != "string" {
|
||||
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry {
|
||||
func (r *ToolRegistry) Register(tool Tool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tools[tool.Name()] = tool
|
||||
name := tool.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
logger.WarnCF("tools", "Tool registration overwrites existing tool",
|
||||
map[string]any{"name": name})
|
||||
}
|
||||
r.tools[name] = tool
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
|
||||
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
for i := range 50 {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
|
||||
+96
-50
@@ -21,53 +21,77 @@ type ExecTool struct {
|
||||
timeout time.Duration
|
||||
denyPatterns []*regexp.Regexp
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
}
|
||||
|
||||
var defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
regexp.MustCompile(`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`), // Match disk wiping commands, avoid matching --format flags
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
var (
|
||||
defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
// Match disk wiping commands, avoid matching --format flags
|
||||
regexp.MustCompile(
|
||||
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
|
||||
),
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
// Block writes to block devices (all common naming schemes).
|
||||
regexp.MustCompile(
|
||||
`>\s*/dev/(sd[a-z]|hd[a-z]|vd[a-z]|xvd[a-z]|nvme\d|mmcblk\d|loop\d|dm-\d|md\d|sr\d|nbd\d)`,
|
||||
),
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
|
||||
// absolutePathPattern matches absolute file paths in commands (Unix and Windows).
|
||||
absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
|
||||
|
||||
// safePaths are kernel pseudo-devices that are always safe to reference in
|
||||
// commands, regardless of workspace restriction. They contain no user data
|
||||
// and cannot cause destructive writes.
|
||||
safePaths = map[string]bool{
|
||||
"/dev/null": true,
|
||||
"/dev/zero": true,
|
||||
"/dev/random": true,
|
||||
"/dev/urandom": true,
|
||||
"/dev/stdin": true,
|
||||
"/dev/stdout": true,
|
||||
"/dev/stderr": true,
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
@@ -75,6 +99,7 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
@@ -95,6 +120,13 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
|
||||
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
|
||||
}
|
||||
for _, pattern := range execConfig.CustomAllowPatterns {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid custom allow pattern %q: %w", pattern, err)
|
||||
}
|
||||
customAllowPatterns = append(customAllowPatterns, re)
|
||||
}
|
||||
} else {
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
@@ -104,6 +136,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
timeout: 60 * time.Second,
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
}, nil
|
||||
}
|
||||
@@ -258,9 +291,20 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
cmd := strings.TrimSpace(command)
|
||||
lower := strings.ToLower(cmd)
|
||||
|
||||
for _, pattern := range t.denyPatterns {
|
||||
// Custom allow patterns exempt a command from deny checks.
|
||||
explicitlyAllowed := false
|
||||
for _, pattern := range t.customAllowPatterns {
|
||||
if pattern.MatchString(lower) {
|
||||
return "Command blocked by safety guard (dangerous pattern detected)"
|
||||
explicitlyAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !explicitlyAllowed {
|
||||
for _, pattern := range t.denyPatterns {
|
||||
if pattern.MatchString(lower) {
|
||||
return "Command blocked by safety guard (dangerous pattern detected)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,16 +331,18 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
pathPattern := regexp.MustCompile(`(?:^|\s|=)([A-Za-z]:\\[^\\"']+|/[a-zA-Z.][^\s"']*)`)
|
||||
matches := pathPattern.FindAllStringSubmatch(cmd, -1)
|
||||
matches := absolutePathPattern.FindAllString(cmd, -1)
|
||||
|
||||
for _, match := range matches {
|
||||
raw := match[1]
|
||||
for _, raw := range matches {
|
||||
p, err := filepath.Abs(raw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if safePaths[p] {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(cwdPath, p)
|
||||
if err != nil {
|
||||
continue
|
||||
|
||||
+106
-25
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// TestShellTool_Success verifies successful command execution
|
||||
@@ -310,6 +312,60 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DevNullAllowed verifies that /dev/null redirections are not blocked (issue #964).
|
||||
func TestShellTool_DevNullAllowed(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool, err := NewExecTool(tmpDir, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
commands := []string{
|
||||
"echo hello 2>/dev/null",
|
||||
"echo hello >/dev/null",
|
||||
"echo hello > /dev/null",
|
||||
"echo hello 2> /dev/null",
|
||||
"echo hello >/dev/null 2>&1",
|
||||
"find " + tmpDir + " -name '*.go' 2>/dev/null",
|
||||
}
|
||||
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_BlockDevices verifies that writes to block devices are blocked (issue #965).
|
||||
func TestShellTool_BlockDevices(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
blocked := []string{
|
||||
"echo x > /dev/sda",
|
||||
"echo x > /dev/hda",
|
||||
"echo x > /dev/vda",
|
||||
"echo x > /dev/xvda",
|
||||
"echo x > /dev/nvme0n1",
|
||||
"echo x > /dev/mmcblk0",
|
||||
"echo x > /dev/loop0",
|
||||
"echo x > /dev/dm-0",
|
||||
"echo x > /dev/md0",
|
||||
"echo x > /dev/sr0",
|
||||
"echo x > /dev/nbd0",
|
||||
}
|
||||
|
||||
for _, cmd := range blocked {
|
||||
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
|
||||
if !result.IsError {
|
||||
t.Errorf("expected block device write to be blocked: %s", cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping
|
||||
// commands (format, mkfs, diskpart) blocks them when preceded by shell separators
|
||||
// but does NOT block legitimate uses like --format flags.
|
||||
@@ -322,7 +378,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// These should be BLOCKED (disk wiping commands)
|
||||
blocked := []struct {
|
||||
blockedCmds := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
@@ -334,7 +390,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
{"diskpart standalone", "diskpart /s script.txt"},
|
||||
}
|
||||
|
||||
for _, tt := range blocked {
|
||||
for _, tt := range blockedCmds {
|
||||
t.Run("blocked_"+tt.name, func(t *testing.T) {
|
||||
result := tool.Execute(ctx, map[string]any{"command": tt.cmd})
|
||||
if !result.IsError {
|
||||
@@ -362,35 +418,60 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_RestrictToWorkspace_HiddenDirs verifies that hidden directory
|
||||
// paths (starting with .) are properly detected by the workspace guard.
|
||||
func TestShellTool_RestrictToWorkspace_HiddenDirs(t *testing.T) {
|
||||
// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
|
||||
// are allowed even when workspace restriction is active.
|
||||
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool, err := NewExecTool(tmpDir, false)
|
||||
tool, err := NewExecTool(tmpDir, true)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
tool.SetRestrictToWorkspace(true)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Reading a hidden dir outside workspace should be blocked
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"command": "cat /.ssh/config",
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected /.ssh/config to be blocked with restrictToWorkspace=true")
|
||||
// These reference paths outside workspace but should be allowed via safePaths.
|
||||
commands := []string{
|
||||
"cat /dev/urandom | head -c 16 | od",
|
||||
"echo test > /dev/null",
|
||||
"dd if=/dev/zero bs=1 count=1",
|
||||
}
|
||||
|
||||
// Flag-attached paths outside workspace should be blocked
|
||||
result2 := tool.Execute(ctx, map[string]any{
|
||||
"command": "grep --include=/etc/passwd pattern",
|
||||
})
|
||||
if !result2.IsError {
|
||||
// This tests the = delimiter fix; --include=/etc/passwd uses = in real
|
||||
// usage but --include /etc/passwd uses space. Both patterns should catch it.
|
||||
// If this specific form isn't blocked, it's acceptable since the primary
|
||||
// concern is the = form (--file=/etc/passwd).
|
||||
_ = result2 // acceptable either way for this pattern variant
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(context.Background(), map[string]any{"command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt
|
||||
// commands from deny pattern checks.
|
||||
func TestShellTool_CustomAllowPatterns(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Tools: config.ToolsConfig{
|
||||
Exec: config.ExecConfig{
|
||||
EnableDenyPatterns: true,
|
||||
CustomAllowPatterns: []string{`\bgit\s+push\s+origin\b`},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool, err := NewExecToolWithConfig("", false, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
// "git push origin main" should be allowed by custom allow pattern.
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "git push origin main",
|
||||
})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// "git push upstream main" should still be blocked (does not match allow pattern).
|
||||
result = tool.Execute(context.Background(), map[string]any{
|
||||
"command": "git push upstream main",
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Errorf("'git push upstream main' should still be blocked by deny pattern")
|
||||
}
|
||||
}
|
||||
|
||||
+86
-61
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -15,6 +16,14 @@ import (
|
||||
|
||||
const (
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
|
||||
// HTTP client timeouts for web tool providers.
|
||||
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
|
||||
perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
|
||||
fetchTimeout = 60 * time.Second // WebFetchTool
|
||||
|
||||
defaultMaxChars = 50000
|
||||
maxRedirects = 5
|
||||
)
|
||||
|
||||
// Pre-compiled regexes for HTML text extraction
|
||||
@@ -74,6 +83,7 @@ type SearchProvider interface {
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -88,11 +98,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -103,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
@@ -143,6 +153,7 @@ type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -174,11 +185,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -226,7 +233,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct {
|
||||
proxy string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -239,11 +247,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -285,7 +289,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
|
||||
maxItems := min(len(matches), count)
|
||||
|
||||
for i := 0; i < maxItems; i++ {
|
||||
for i := range maxItems {
|
||||
urlStr := matches[i][1]
|
||||
title := stripTags(matches[i][2])
|
||||
title = strings.TrimSpace(title)
|
||||
@@ -293,9 +297,9 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
// URL decoding if needed
|
||||
if strings.Contains(urlStr, "uddg=") {
|
||||
if u, err := url.QueryUnescape(urlStr); err == nil {
|
||||
idx := strings.Index(u, "uddg=")
|
||||
if idx != -1 {
|
||||
urlStr = u[idx+5:]
|
||||
_, after, ok := strings.Cut(u, "uddg=")
|
||||
if ok {
|
||||
urlStr = after
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -322,6 +326,7 @@ func stripTags(content string) string {
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -356,11 +361,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 30*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -415,43 +416,60 @@ type WebSearchToolOptions struct {
|
||||
Proxy string
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
|
||||
}
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
|
||||
}
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
|
||||
}
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.TavilyMaxResults > 0 {
|
||||
maxResults = opts.TavilyMaxResults
|
||||
}
|
||||
} else if opts.DuckDuckGoEnabled {
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
|
||||
}
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &WebSearchTool{
|
||||
provider: provider,
|
||||
maxResults: maxResults,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Name() string {
|
||||
@@ -506,27 +524,40 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
||||
}
|
||||
|
||||
type WebFetchTool struct {
|
||||
maxChars int
|
||||
proxy string
|
||||
maxChars int
|
||||
proxy string
|
||||
client *http.Client
|
||||
fetchLimitBytes int64
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int) *WebFetchTool {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
}
|
||||
return &WebFetchTool{
|
||||
maxChars: maxChars,
|
||||
}
|
||||
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
// createHTTPClient cannot fail with an empty proxy string.
|
||||
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
client, err := createHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fetchLimitBytes <= 0 {
|
||||
fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
|
||||
}
|
||||
return &WebFetchTool{
|
||||
maxChars: maxChars,
|
||||
proxy: proxy,
|
||||
}
|
||||
maxChars: maxChars,
|
||||
proxy: proxy,
|
||||
client: client,
|
||||
fetchLimitBytes: fetchLimitBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Name() string {
|
||||
@@ -588,27 +619,21 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(t.proxy, 60*time.Second)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
|
||||
}
|
||||
|
||||
// Configure redirect handling
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("stopped after 5 redirects")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
|
||||
resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
@@ -652,14 +677,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
ForLLM: string(resultJSON),
|
||||
ForUser: fmt.Sprintf(
|
||||
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
|
||||
len(text),
|
||||
urlStr,
|
||||
extractor,
|
||||
truncated,
|
||||
),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+144
-35
@@ -1,15 +1,21 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const testFetchLimit = int64(10 * 1024 * 1024)
|
||||
|
||||
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
||||
func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -19,7 +25,11 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -32,14 +42,14 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain the fetched content
|
||||
if !strings.Contains(result.ForUser, "Test Page") {
|
||||
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
|
||||
// ForLLM should contain the fetched content (full JSON result)
|
||||
if !strings.Contains(result.ForLLM, "Test Page") {
|
||||
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should contain summary
|
||||
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
|
||||
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
|
||||
// ForUser should contain summary
|
||||
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
||||
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +65,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -68,15 +82,19 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain formatted JSON
|
||||
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
|
||||
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
|
||||
// ForLLM should contain formatted JSON
|
||||
if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
|
||||
t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": "not-a-valid-url",
|
||||
@@ -97,7 +115,11 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": "ftp://example.com/file.txt",
|
||||
@@ -118,7 +140,11 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
|
||||
@@ -146,7 +172,11 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(1000) // Limit to 1000 chars
|
||||
tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -159,9 +189,9 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain truncated content (not the full 20000 chars)
|
||||
// ForLLM should contain truncated content (not the full 20000 chars)
|
||||
resultMap := make(map[string]any)
|
||||
json.Unmarshal([]byte(result.ForUser), &resultMap)
|
||||
json.Unmarshal([]byte(result.ForLLM), &resultMap)
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if len(text) > 1100 { // Allow some margin
|
||||
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
||||
@@ -174,15 +204,64 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
// Create a mock HTTP server
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Generate a payload intentionally larger than our limit.
|
||||
// Limit: 10 * 1024 * 1024 (10MB). We generate 10MB + 100 bytes of the letter 'A'.
|
||||
largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100)
|
||||
|
||||
w.Write(largeData)
|
||||
}))
|
||||
// Ensure the server is shut down at the end of the test
|
||||
defer ts.Close()
|
||||
|
||||
// Initialize the tool
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Prepare the arguments pointing to the URL of our local mock server
|
||||
args := map[string]any{
|
||||
"url": ts.URL,
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
ctx := context.Background()
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Assuming ErrorResult sets the ForLLM field with the error text.
|
||||
if result == nil {
|
||||
t.Fatal("expected a ToolResult, got nil")
|
||||
}
|
||||
|
||||
// Search for the exact error string we set earlier in the Execute method
|
||||
expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
|
||||
|
||||
if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
|
||||
t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when Brave API key is empty")
|
||||
}
|
||||
|
||||
// Also nil when nothing is enabled
|
||||
tool = NewWebSearchTool(WebSearchToolOptions{})
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when no provider is enabled")
|
||||
}
|
||||
@@ -190,7 +269,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
|
||||
@@ -215,7 +297,11 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": server.URL,
|
||||
@@ -228,14 +314,14 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
|
||||
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
|
||||
// ForLLM should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
|
||||
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should NOT contain script or style tags
|
||||
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
|
||||
// Should NOT contain script or style tags in ForLLM
|
||||
if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,7 +402,11 @@ func TestWebFetchTool_extractText(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"url": "https://",
|
||||
@@ -438,15 +528,22 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
|
||||
if tool.maxChars != 1024 {
|
||||
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else if tool.maxChars != 1024 {
|
||||
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
|
||||
}
|
||||
|
||||
if tool.proxy != "http://127.0.0.1:7890" {
|
||||
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
|
||||
tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
|
||||
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
if tool.maxChars != 50000 {
|
||||
t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
|
||||
}
|
||||
@@ -454,12 +551,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "k",
|
||||
PerplexityMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*PerplexitySearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
|
||||
@@ -470,12 +570,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("brave", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "k",
|
||||
BraveMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*BraveSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
|
||||
@@ -486,11 +589,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("duckduckgo", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: true,
|
||||
DuckDuckGoMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
|
||||
@@ -542,12 +648,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
|
||||
Reference in New Issue
Block a user