diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index d25ec1254..92946ef98 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -82,22 +82,19 @@ func isAllowedPath(path string, patterns []*regexp.Regexp) bool { } cleaned := filepath.Clean(path) + if !filepath.IsAbs(cleaned) { + return false + } if !matchesAllowedPath(cleaned, patterns) { return false } - resolved, err := filepath.EvalSymlinks(cleaned) - if err == nil { - return matchesAllowedPath(resolved, patterns) - } - if os.IsNotExist(err) { - parentResolved, parentErr := resolveExistingAncestor(filepath.Dir(cleaned)) - if parentErr == nil { - return matchesAllowedPath(parentResolved, patterns) - } + resolved, err := resolvePathAgainstExistingAncestor(cleaned) + if err != nil { + return false } - return false + return matchesAllowedPath(resolved, patterns) } func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool { @@ -122,6 +119,29 @@ func resolveExistingAncestor(path string) (string, error) { } } +func resolvePathAgainstExistingAncestor(path string) (string, error) { + cleaned := filepath.Clean(path) + for current := cleaned; ; current = filepath.Dir(current) { + resolved, err := filepath.EvalSymlinks(current) + if err == nil { + suffix, relErr := filepath.Rel(current, cleaned) + if relErr != nil { + return "", relErr + } + if suffix == "." { + return filepath.Clean(resolved), nil + } + return filepath.Clean(filepath.Join(resolved, suffix)), nil + } + if !os.IsNotExist(err) { + return "", err + } + if filepath.Dir(current) == current { + return "", os.ErrNotExist + } + } +} + func isWithinWorkspace(candidate, workspace string) bool { rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate)) return err == nil && filepath.IsLocal(rel) @@ -661,7 +681,7 @@ type whitelistFs struct { } func (w *whitelistFs) matches(path string) bool { - return matchesAllowedPath(path, w.patterns) + return isAllowedPath(path, w.patterns) } func (w *whitelistFs) ReadFile(path string) ([]byte, error) { diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 0bbf6caf0..78d69273f 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -521,6 +521,55 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) { } } +func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) { + workspace := t.TempDir() + allowedDir := t.TempDir() + secretDir := t.TempDir() + secretFile := filepath.Join(secretDir, "secret.txt") + if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil { + t.Fatalf("WriteFile(secretFile) error = %v", err) + } + + linkPath := filepath.Join(allowedDir, "link_out") + if err := os.Symlink(secretDir, linkPath); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))} + tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns) + + result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")}) + if !result.IsError { + t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM) + } +} + +func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) { + workspace := t.TempDir() + rootDir := t.TempDir() + allowedDir := filepath.Join(rootDir, "allowed") + targetFile := filepath.Join(allowedDir, "nested", "file.txt") + + patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))} + tool := NewWriteFileTool(workspace, true, patterns) + + result := tool.Execute(context.Background(), map[string]any{ + "path": targetFile, + "content": "outside write", + }) + if result.IsError { + t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM) + } + + data, err := os.ReadFile(targetFile) + if err != nil { + t.Fatalf("ReadFile(targetFile) error = %v", err) + } + if string(data) != "outside write" { + t.Fatalf("target file content = %q, want %q", string(data), "outside write") + } +} + // TestReadFileTool_ChunkedReading verifies the pagination logic of the tool // by reading a file in multiple chunks using 'offset' and 'length'. func TestReadFileTool_ChunkedReading(t *testing.T) {