mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
prevent read binary file in tool
This commit is contained in:
+103
-2
@@ -1,9 +1,12 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -123,11 +126,37 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, err := t.fs.ReadFile(path)
|
||||
// open file instead of loading it all into memory
|
||||
file, err := t.fs.Open(path)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
return NewToolResult(string(content))
|
||||
defer file.Close()
|
||||
|
||||
// read only an initial chunk (512 bytes is the standard for MIME sniffing)
|
||||
header := make([]byte, 512)
|
||||
n, err := file.Read(header)
|
||||
if err != nil && err != io.EOF {
|
||||
return ErrorResult(fmt.Sprintf("failed to read file header: %v", err))
|
||||
}
|
||||
header = header[:n]
|
||||
|
||||
// Reject binary files now, before reading the rest of the file into memory
|
||||
if isBinaryFile(header) {
|
||||
return ErrorResult(fmt.Sprintf("cannot read file %q: appears to be a binary file (e.g., PDF, image, executable)", filepath.Base(path)))
|
||||
}
|
||||
|
||||
// If it is text, let's read the rest of the file
|
||||
// (io.ReadAll continues from wherever the header read stopped)
|
||||
rest, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read file content: %v", err))
|
||||
}
|
||||
|
||||
// Recompose the complete content by merging the header and the rest
|
||||
fullContent := append(header, rest...)
|
||||
|
||||
return NewToolResult(string(fullContent))
|
||||
}
|
||||
|
||||
type WriteFileTool struct {
|
||||
@@ -249,6 +278,7 @@ type fileSystem interface {
|
||||
ReadFile(path string) ([]byte, error)
|
||||
WriteFile(path string, data []byte) error
|
||||
ReadDir(path string) ([]os.DirEntry, error)
|
||||
Open(path string) (fs.File, error)
|
||||
}
|
||||
|
||||
// hostFs is an unrestricted fileSystem that operates directly on the host filesystem.
|
||||
@@ -278,6 +308,20 @@ func (h *hostFs) WriteFile(path string, data []byte) error {
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
func (h *hostFs) Open(path string) (fs.File, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("failed to open file: file not found: %w", err)
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return nil, fmt.Errorf("failed to open file: access denied: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root.
|
||||
type sandboxFs struct {
|
||||
workspace string
|
||||
@@ -389,6 +433,26 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
return entries, err
|
||||
}
|
||||
|
||||
func (r *sandboxFs) Open(path string) (fs.File, error) {
|
||||
var f fs.File
|
||||
err := r.execute(path, func(root *os.Root, relPath string) error {
|
||||
file, err := root.Open(relPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to open file: file not found: %w", err)
|
||||
}
|
||||
if os.IsPermission(err) || strings.Contains(err.Error(), "escapes from parent") ||
|
||||
strings.Contains(err.Error(), "permission denied") {
|
||||
return fmt.Errorf("failed to open file: access denied: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
f = file
|
||||
return nil
|
||||
})
|
||||
return f, err
|
||||
}
|
||||
|
||||
// whitelistFs wraps a sandboxFs and allows access to specific paths outside
|
||||
// the workspace when they match any of the provided patterns.
|
||||
type whitelistFs struct {
|
||||
@@ -427,6 +491,13 @@ func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
|
||||
return w.sandbox.ReadDir(path)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) Open(path string) (fs.File, error) {
|
||||
if w.matches(path) {
|
||||
return w.host.Open(path)
|
||||
}
|
||||
return w.sandbox.Open(path)
|
||||
}
|
||||
|
||||
// buildFs returns the appropriate fileSystem implementation based on restriction
|
||||
// settings and optional path whitelist patterns.
|
||||
func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
|
||||
@@ -461,3 +532,33 @@ func getSafeRelPath(workspace, path string) (string, error) {
|
||||
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
// isBinaryFile reports whether content looks like binary data.
//
// Only the first 512 bytes are inspected. The content is treated as binary
// when that prefix contains a NUL byte, or when http.DetectContentType
// classifies it as a PDF or as any image, video or audio type. Empty content
// is never binary.
func isBinaryFile(content []byte) bool {
	const sniffLen = 512

	head := content
	if len(head) > sniffLen {
		head = head[:sniffLen]
	}
	if len(head) == 0 {
		return false
	}

	// A NUL byte is the classic marker of non-text data.
	if bytes.IndexByte(head, 0) >= 0 {
		return true
	}

	// Signature-based detection catches formats whose header happens to
	// contain no NUL byte, such as PDF.
	mimeType := http.DetectContentType(head)
	if mimeType == "application/pdf" {
		return true
	}
	for _, family := range []string{"image/", "video/", "audio/"} {
		if strings.HasPrefix(mimeType, family) {
			return true
		}
	}
	return false
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
@@ -520,3 +520,93 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBinaryFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty content",
|
||||
content: []byte(""),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "plain text",
|
||||
content: []byte("This is a normal text file with punctuation and 12345 numbers."),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "contains null byte",
|
||||
content: []byte("plain text\x00followed by a null byte"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pdf header",
|
||||
content: []byte("%PDF-1.4\n%\xE2\xE3\xCF\xD3\n1 0 obj\n<</Type/Catalog/Pages 2 0 R>>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "png magic bytes",
|
||||
content: []byte("\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "jpeg magic bytes",
|
||||
content: []byte("\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "html text (not binary)",
|
||||
content: []byte("<!DOCTYPE html><html><body><h1>Ciao</h1></body></html>"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "json text (not binary)",
|
||||
content: []byte(`{"key": "value", "number": 42}`),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "markdown text (not binary)",
|
||||
content: []byte("# Markdown Title\n\nThis is a **bold text** and a [link](https://example.com)."),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isBinaryFile(tt.content)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isBinaryFile() for %q returned %v, expected %v", tt.name, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemTool_ReadFile_BlocksBinary(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a dummy binary file (e.g., a PDF).
|
||||
testFile := filepath.Join(tmpDir, "fake_document.pdf")
|
||||
fakePDFContent := []byte("%PDF-1.4\n% Some null test bytes\x00\x00\x00")
|
||||
os.WriteFile(testFile, fakePDFContent, 0o644)
|
||||
|
||||
tool := NewReadFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("An error was expected when trying to read a binary file, but instead it was successful")
|
||||
}
|
||||
|
||||
// The error should mention that it is a binary file
|
||||
expectedMsg := "appears to be a binary file"
|
||||
if !strings.Contains(result.ForLLM, expectedMsg) {
|
||||
t.Errorf("The error message '%s' was expected, obtained: %s", expectedMsg, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user