diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 0c7baa1ee..1c3635322 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/memory" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" @@ -66,7 +67,7 @@ func NewAgentInstance( readRestrict := restrict && !defaults.AllowReadOutsideWorkspace // Compile path whitelist patterns from config. - allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths) + allowReadPaths := buildAllowReadPatterns(cfg) allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths) toolsRegistry := tools.NewToolRegistry() @@ -82,7 +83,7 @@ func NewAgentInstance( toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths)) } if cfg.Tools.IsToolEnabled("exec") { - execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg) + execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths) if err != nil { log.Fatalf("Critical error: unable to initialize exec tool: %v", err) } @@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp { return compiled } +func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp { + var configured []string + if cfg != nil { + configured = cfg.Tools.AllowReadPaths + } + + compiled := compilePatterns(configured) + mediaDirPattern := regexp.MustCompile(mediaTempDirPattern()) + for _, pattern := range compiled { + if pattern.String() == mediaDirPattern.String() { + return compiled + } + } + + return append(compiled, mediaDirPattern) +} + +func mediaTempDirPattern() string { + sep := regexp.QuoteMeta(string(os.PathSeparator)) + return "^" + regexp.QuoteMeta(filepath.Clean(media.TempDir())) + "(?:" + sep + "|$)" +} + // Close releases resources held by the agent's session store. func (a *AgentInstance) Close() error { if a.Sessions != nil { diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index 4f41ecd1c..f8057bb2f 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -1,10 +1,14 @@ package agent import ( + "context" "os" + "path/filepath" + "strings" "testing" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" ) func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) { @@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) { }) } } + +func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) { + workspace := t.TempDir() + mediaDir := media.TempDir() + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + t.Fatalf("MkdirAll(mediaDir) error = %v", err) + } + + mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt") + if err != nil { + t.Fatalf("CreateTemp(mediaDir) error = %v", err) + } + mediaPath := mediaFile.Name() + if _, err := mediaFile.WriteString("attachment content"); err != nil { + mediaFile.Close() + t.Fatalf("WriteString(mediaFile) error = %v", err) + } + if err := mediaFile.Close(); err != nil { + t.Fatalf("Close(mediaFile) error = %v", err) + } + t.Cleanup(func() { _ = os.Remove(mediaPath) }) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: workspace, + Model: "test-model", + RestrictToWorkspace: true, + }, + }, + Tools: config.ToolsConfig{ + ReadFile: config.ReadFileToolConfig{Enabled: true}, + ListDir: config.ToolConfig{Enabled: true}, + Exec: config.ExecConfig{ + ToolConfig: config.ToolConfig{Enabled: true}, + EnableDenyPatterns: true, + AllowRemote: true, + }, + }, + } + + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{}) + + readTool, ok := agent.Tools.Get("read_file") + if !ok { + t.Fatal("read_file tool not registered") + } + readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath}) + if readResult.IsError { + t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM) + } + if !strings.Contains(readResult.ForLLM, "attachment content") { + t.Fatalf("read_file output missing media content: %s", readResult.ForLLM) + } + + listTool, ok := agent.Tools.Get("list_dir") + if !ok { + t.Fatal("list_dir tool not registered") + } + listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir}) + if listResult.IsError { + t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM) + } + if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) { + t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM) + } + + execTool, ok := agent.Tools.Get("exec") + if !ok { + t.Fatal("exec tool not registered") + } + execResult := execTool.Execute(context.Background(), map[string]any{ + "command": "cat " + filepath.Base(mediaPath), + "working_dir": mediaDir, + }) + if execResult.IsError { + t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM) + } + if !strings.Contains(execResult.ForLLM, "attachment content") { + t.Fatalf("exec output missing media content: %s", execResult.ForLLM) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index dfa339dee..8a0303b50 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -117,6 +117,8 @@ func registerSharedTools( registry *AgentRegistry, provider providers.LLMProvider, ) { + allowReadPaths := buildAllowReadPatterns(cfg) + for _, agentID := range registry.ListAgentIDs() { agent, ok := registry.GetAgent(agentID) if !ok { @@ -195,6 +197,7 @@ func registerSharedTools( cfg.Agents.Defaults.RestrictToWorkspace, cfg.Agents.Defaults.GetMaxMediaSize(), nil, + allowReadPaths, ) agent.Tools.Register(sendFileTool) } diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 6b1cb1475..21385b01b 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -22,6 +22,10 @@ const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow // validatePath ensures the given path is within the workspace if restrict is true. func validatePath(path, workspace string, restrict bool) (string, error) { + return validatePathWithAllowPaths(path, workspace, restrict, nil) +} + +func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) { if workspace == "" { return path, fmt.Errorf("workspace is not defined") } @@ -42,6 +46,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) { } if restrict { + if isAllowedPath(absPath, patterns) { + return absPath, nil + } + if !isWithinWorkspace(absPath, absWorkspace) { return "", fmt.Errorf("access denied: path is outside the workspace") } @@ -73,6 +81,39 @@ func validatePath(path, workspace string, restrict bool) (string, error) { return absPath, nil } +func isAllowedPath(path string, patterns []*regexp.Regexp) bool { + if len(patterns) == 0 { + return false + } + + cleaned := filepath.Clean(path) + if !matchesAllowedPath(cleaned, patterns) { + return false + } + + resolved, err := filepath.EvalSymlinks(cleaned) + if err == nil { + return matchesAllowedPath(resolved, patterns) + } + if os.IsNotExist(err) { + parentResolved, parentErr := resolveExistingAncestor(filepath.Dir(cleaned)) + if parentErr == nil { + return matchesAllowedPath(parentResolved, patterns) + } + } + + return false +} + +func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool { + for _, pattern := range patterns { + if pattern.MatchString(path) { + return true + } + } + return false +} + func resolveExistingAncestor(path string) (string, error) { for current := filepath.Clean(path); ; current = filepath.Dir(current) { if resolved, err := filepath.EvalSymlinks(current); err == nil { @@ -625,12 +666,7 @@ type whitelistFs struct { } func (w *whitelistFs) matches(path string) bool { - for _, p := range w.patterns { - if p.MatchString(path) { - return true - } - } - return false + return matchesAllowedPath(path, w.patterns) } func (w *whitelistFs) ReadFile(path string) ([]byte, error) { diff --git a/pkg/tools/send_file.go b/pkg/tools/send_file.go index 1a03e58ed..a67bd4210 100644 --- a/pkg/tools/send_file.go +++ b/pkg/tools/send_file.go @@ -6,6 +6,7 @@ import ( "mime" "os" "path/filepath" + "regexp" "strings" "github.com/h2non/filetype" @@ -21,20 +22,32 @@ type SendFileTool struct { restrict bool maxFileSize int mediaStore media.MediaStore + allowPaths []*regexp.Regexp defaultChannel string defaultChatID string } -func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool { +func NewSendFileTool( + workspace string, + restrict bool, + maxFileSize int, + store media.MediaStore, + allowPaths ...[]*regexp.Regexp, +) *SendFileTool { if maxFileSize <= 0 { maxFileSize = config.DefaultMaxMediaSize } + var patterns []*regexp.Regexp + if len(allowPaths) > 0 { + patterns = allowPaths[0] + } return &SendFileTool{ workspace: workspace, restrict: restrict, maxFileSize: maxFileSize, mediaStore: store, + allowPaths: patterns, } } @@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult("media store not configured") } - resolved, err := validatePath(path, t.workspace, t.restrict) + resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths) if err != nil { return ErrorResult(fmt.Sprintf("invalid path: %v", err)) } diff --git a/pkg/tools/send_file_test.go b/pkg/tools/send_file_test.go index 08d129674..6daaab31c 100644 --- a/pkg/tools/send_file_test.go +++ b/pkg/tools/send_file_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "regexp" "strings" "testing" @@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) { } } +func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) { + workspace := t.TempDir() + mediaDir := media.TempDir() + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + t.Fatalf("MkdirAll(mediaDir) error = %v", err) + } + + testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt") + if err != nil { + t.Fatalf("CreateTemp(mediaDir) error = %v", err) + } + testPath := testFile.Name() + if _, err := testFile.WriteString("forward me"); err != nil { + testFile.Close() + t.Fatalf("WriteString(testFile) error = %v", err) + } + if err := testFile.Close(); err != nil { + t.Fatalf("Close(testFile) error = %v", err) + } + t.Cleanup(func() { _ = os.Remove(testPath) }) + + pattern := regexp.MustCompile( + "^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)", + ) + + store := media.NewFileMediaStore() + tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern}) + tool.SetContext("feishu", "chat123") + + result := tool.Execute(context.Background(), map[string]any{"path": testPath}) + if result.IsError { + t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM) + } + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } +} + func TestDetectMediaType_MagicBytes(t *testing.T) { dir := t.TempDir() diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 9ea05bb12..0dc85ae21 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -23,6 +23,7 @@ type ExecTool struct { denyPatterns []*regexp.Regexp allowPatterns []*regexp.Regexp customAllowPatterns []*regexp.Regexp + allowedPathPatterns []*regexp.Regexp restrictToWorkspace bool allowRemote bool } @@ -95,14 +96,23 @@ var ( } ) -func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) { - return NewExecToolWithConfig(workingDir, restrict, nil) +func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) { + return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...) } -func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) { +func NewExecToolWithConfig( + workingDir string, + restrict bool, + config *config.Config, + allowPaths ...[]*regexp.Regexp, +) (*ExecTool, error) { denyPatterns := make([]*regexp.Regexp, 0) customAllowPatterns := make([]*regexp.Regexp, 0) + var allowedPathPatterns []*regexp.Regexp allowRemote := true + if len(allowPaths) > 0 { + allowedPathPatterns = allowPaths[0] + } if config != nil { execConfig := config.Tools.Exec @@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf denyPatterns: denyPatterns, allowPatterns: nil, customAllowPatterns: customAllowPatterns, + allowedPathPatterns: allowedPathPatterns, restrictToWorkspace: restrict, allowRemote: allowRemote, }, nil @@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult cwd := t.workingDir if wd, ok := args["working_dir"].(string); ok && wd != "" { if t.restrictToWorkspace && t.workingDir != "" { - resolvedWD, err := validatePath(wd, t.workingDir, true) + resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns) if err != nil { return ErrorResult("Command blocked by safety guard (" + err.Error() + ")") } @@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult if err != nil { return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err)) } - absWorkspace, _ := filepath.Abs(t.workingDir) - wsResolved, _ := filepath.EvalSymlinks(absWorkspace) - if wsResolved == "" { - wsResolved = absWorkspace + if isAllowedPath(resolved, t.allowedPathPatterns) { + cwd = resolved + } else { + absWorkspace, _ := filepath.Abs(t.workingDir) + wsResolved, _ := filepath.EvalSymlinks(absWorkspace) + if wsResolved == "" { + wsResolved = absWorkspace + } + rel, err := filepath.Rel(wsResolved, resolved) + if err != nil || !filepath.IsLocal(rel) { + return ErrorResult("Command blocked by safety guard (working directory escaped workspace)") + } + cwd = resolved } - rel, err := filepath.Rel(wsResolved, resolved) - if err != nil || !filepath.IsLocal(rel) { - return ErrorResult("Command blocked by safety guard (working directory escaped workspace)") - } - cwd = resolved } // timeout == 0 means no timeout @@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string { if safePaths[p] { continue } + if isAllowedPath(p, t.allowedPathPatterns) { + continue + } rel, err := filepath.Rel(cwdPath, p) if err != nil {