From e931756feef9cf22353759e969e9171aac8eb4a6 Mon Sep 17 00:00:00 2001 From: Mauro Date: Thu, 19 Mar 2026 04:22:52 +0100 Subject: [PATCH] feat(tool): overwrite flag in write_file (#1761) * feat: overwrite flag in write file tool * fix error message --- pkg/tools/filesystem.go | 15 ++++- pkg/tools/filesystem_test.go | 115 +++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 1 deletion(-) diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index ae356f248..39d45013d 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -496,7 +496,7 @@ func (t *WriteFileTool) Name() string { } func (t *WriteFileTool) Description() string { - return "Write content to a file" + return "Write content to a file. If the file already exists, you must set overwrite=true to replace it." } func (t *WriteFileTool) Parameters() map[string]any { @@ -511,6 +511,11 @@ func (t *WriteFileTool) Parameters() map[string]any { "type": "string", "description": "Content to write to the file", }, + "overwrite": map[string]any{ + "type": "boolean", + "description": "Must be set to true to overwrite an existing file.", + "default": false, + }, }, "required": []string{"path", "content"}, } @@ -527,6 +532,14 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolR return ErrorResult("content is required") } + overwrite, _ := args["overwrite"].(bool) + + if !overwrite { + if _, err := t.fs.Open(path); err == nil { + return ErrorResult(fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path)) + } + } + if err := t.fs.WriteFile(path, []byte(content)); err != nil { return ErrorResult(err.Error()) } diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 5ebf38df2..0b4dd310b 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -189,6 +189,121 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { } } +// TestFilesystemTool_WriteFile_OverwriteDefaultBlocked verifies that writing to an +// existing file without overwrite=true returns an error. +func TestFilesystemTool_WriteFile_OverwriteDefaultBlocked(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "existing.txt") + os.WriteFile(testFile, []byte("original"), 0o644) + + tool := NewWriteFileTool("", false) + result := tool.Execute(context.Background(), map[string]any{ + "path": testFile, + "content": "new content", + }) + + assert.True(t, result.IsError, "expected error when overwriting without overwrite=true") + assert.Contains(t, result.ForLLM, "already exists") + assert.Contains(t, result.ForLLM, "overwrite=true") + + // Original content must be untouched + data, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, "original", string(data)) +} + +// TestFilesystemTool_WriteFile_OverwriteExplicitAllowed verifies that setting +// overwrite=true replaces the existing file. +func TestFilesystemTool_WriteFile_OverwriteExplicitAllowed(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "existing.txt") + os.WriteFile(testFile, []byte("original"), 0o644) + + tool := NewWriteFileTool("", false) + result := tool.Execute(context.Background(), map[string]any{ + "path": testFile, + "content": "replaced", + "overwrite": true, + }) + + assert.False(t, result.IsError, "expected success with overwrite=true, got: %s", result.ForLLM) + + data, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, "replaced", string(data)) +} + +// TestFilesystemTool_WriteFile_NewFileNoOverwriteFlag verifies that a new (non-existing) +// file can be written without setting overwrite=true. +func TestFilesystemTool_WriteFile_NewFileNoOverwriteFlag(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "newfile.txt") + + tool := NewWriteFileTool("", false) + result := tool.Execute(context.Background(), map[string]any{ + "path": testFile, + "content": "brand new", + }) + + assert.False(t, result.IsError, "expected success for new file, got: %s", result.ForLLM) + + data, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, "brand new", string(data)) +} + +// TestFilesystemTool_WriteFile_OverwriteFalseExplicitBlocked verifies that +// explicitly passing overwrite=false also blocks overwriting. +func TestFilesystemTool_WriteFile_OverwriteFalseExplicitBlocked(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "existing.txt") + os.WriteFile(testFile, []byte("original"), 0o644) + + tool := NewWriteFileTool("", false) + result := tool.Execute(context.Background(), map[string]any{ + "path": testFile, + "content": "new content", + "overwrite": false, + }) + + assert.True(t, result.IsError, "expected error when overwrite=false") + assert.Contains(t, result.ForLLM, "already exists") + + data, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, "original", string(data)) +} + +// TestFilesystemTool_WriteFile_OverwriteSandboxed verifies the overwrite guard +// works correctly in restricted (sandbox) mode. +func TestFilesystemTool_WriteFile_OverwriteSandboxed(t *testing.T) { + workspace := t.TempDir() + testFile := "file.txt" + os.WriteFile(filepath.Join(workspace, testFile), []byte("original"), 0o644) + + tool := NewWriteFileTool(workspace, true) + + // Without overwrite=true → blocked + result := tool.Execute(context.Background(), map[string]any{ + "path": testFile, + "content": "new content", + }) + assert.True(t, result.IsError, "expected error in sandbox mode without overwrite=true") + assert.Contains(t, result.ForLLM, "already exists") + + // With overwrite=true → allowed + result = tool.Execute(context.Background(), map[string]any{ + "path": testFile, + "content": "replaced in sandbox", + "overwrite": true, + }) + assert.False(t, result.IsError, "expected success in sandbox mode with overwrite=true, got: %s", result.ForLLM) + + data, err := os.ReadFile(filepath.Join(workspace, testFile)) + assert.NoError(t, err) + assert.Equal(t, "replaced in sandbox", string(data)) +} + // TestFilesystemTool_ListDir_Success verifies successful directory listing func TestFilesystemTool_ListDir_Success(t *testing.T) { tmpDir := t.TempDir()