diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index cd8da3195..3c518dd94 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -1,9 +1,12 @@ package tools import ( + "bytes" "context" "fmt" + "io" "io/fs" + "net/http" "os" "path/filepath" "regexp" @@ -123,11 +126,37 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult("path is required") } - content, err := t.fs.ReadFile(path) + // open file instead of loading it all into memory + file, err := t.fs.Open(path) if err != nil { return ErrorResult(err.Error()) } - return NewToolResult(string(content)) + defer file.Close() + + // read only an initial chunk (512 bytes is the standard for MIME sniffing) + header := make([]byte, 512) + n, err := file.Read(header) + if err != nil && err != io.EOF { + return ErrorResult(fmt.Sprintf("failed to read file header: %v", err)) + } + header = header[:n] + + // Lock the binaries now before using more RAM + if isBinaryFile(header) { + return ErrorResult(fmt.Sprintf("cannot read file %q: appears to be a binary file (e.g., PDF, image, executable)", filepath.Base(path))) + } + + // If it is text, let's read the rest of the file + // (io.ReadAll will continue reading starting from byte 513) + rest, err := io.ReadAll(file) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to read file content: %v", err)) + } + + // Recompose the complete content by merging the header and the rest + fullContent := append(header, rest...) + + return NewToolResult(string(fullContent)) } type WriteFileTool struct { @@ -249,6 +278,7 @@ type fileSystem interface { ReadFile(path string) ([]byte, error) WriteFile(path string, data []byte) error ReadDir(path string) ([]os.DirEntry, error) + Open(path string) (fs.File, error) } // hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem. @@ -278,6 +308,20 @@ func (h *hostFs) WriteFile(path string, data []byte) error { return fileutil.WriteFileAtomic(path, data, 0o600) } +func (h *hostFs) Open(path string) (fs.File, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("failed to open file: file not found: %w", err) + } + if os.IsPermission(err) { + return nil, fmt.Errorf("failed to open file: access denied: %w", err) + } + return nil, fmt.Errorf("failed to open file: %w", err) + } + return f, nil +} + // sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root. type sandboxFs struct { workspace string @@ -389,6 +433,26 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) { return entries, err } +func (r *sandboxFs) Open(path string) (fs.File, error) { + var f fs.File + err := r.execute(path, func(root *os.Root, relPath string) error { + file, err := root.Open(relPath) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("failed to open file: file not found: %w", err) + } + if os.IsPermission(err) || strings.Contains(err.Error(), "escapes from parent") || + strings.Contains(err.Error(), "permission denied") { + return fmt.Errorf("failed to open file: access denied: %w", err) + } + return fmt.Errorf("failed to open file: %w", err) + } + f = file + return nil + }) + return f, err +} + // whitelistFs wraps a sandboxFs and allows access to specific paths outside // the workspace when they match any of the provided patterns. type whitelistFs struct { @@ -427,6 +491,13 @@ func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) { return w.sandbox.ReadDir(path) } +func (w *whitelistFs) Open(path string) (fs.File, error) { + if w.matches(path) { + return w.host.Open(path) + } + return w.sandbox.Open(path) +} + // buildFs returns the appropriate fileSystem implementation based on restriction // settings and optional path whitelist patterns. func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem { @@ -461,3 +532,33 @@ func getSafeRelPath(workspace, path string) (string, error) { return rel, nil } + +// isBinaryFile uses common heuristics to determine if the content is a binary file. +func isBinaryFile(content []byte) bool { + if len(content) == 0 { + return false + } + + // Sample the first 512 bytes (or less if the file is smaller) + limit := len(content) + if limit > 512 { + limit = 512 + } + sample := content[:limit] + + // Check for NUL bytes in the sample (standard binary detection) + if bytes.IndexByte(sample, 0) != -1 { + return true + } + + // Use standard library content type detection to catch specific formats like PDF + contentType := http.DetectContentType(sample) + if contentType == "application/pdf" || + strings.HasPrefix(contentType, "image/") || + strings.HasPrefix(contentType, "video/") || + strings.HasPrefix(contentType, "audio/") { + return true + } + + return false +} diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 666004cd4..2868431e0 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -59,7 +59,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { } // Should contain error message - if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") { t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) } } @@ -520,3 +520,93 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) { t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM) } } + +func TestIsBinaryFile(t *testing.T) { + tests := []struct { + name string + content []byte + expected bool + }{ + { + name: "empty content", + content: []byte(""), + expected: false, + }, + { + name: "plain text", + content: []byte("This is a normal text file with punctuation and 12345 numbers."), + expected: false, + }, + { + name: "contains null byte", + content: []byte("plain text\x00followed by a null byte"), + expected: true, + }, + { + name: "pdf header", + content: []byte("%PDF-1.4\n%\xE2\xE3\xCF\xD3\n1 0 obj\n<>"), + expected: true, + }, + { + name: "png magic bytes", + content: []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00"), + expected: true, + }, + { + name: "jpeg magic bytes", + content: []byte("\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H"), + expected: true, + }, + { + name: "html text (not binary)", + content: []byte("