Files
picoclaw/pkg/tools/shell.go
T
I Putu Eddy Irawan ee5b61884a fix: migration ModelName, reasoning_content, shell regex, loop boundary
1. migration.go: Set ModelName to userModel when provider matches so
   GetModelConfig(userModel) can find the entry. Previously the migration
   created entries with the provider name as ModelName (e.g. "moonshot")
   but lookup used the model name (e.g. "k2p5"), causing "model not found".

2. openai_compat/provider.go: Preserve reasoning_content in conversation
   history. Thinking models (e.g. Kimi K2, DeepSeek-R1) return
   reasoning_content which must be echoed back. Without it, APIs return
   400: "thinking is enabled but reasoning_content is missing".

3. shell.go: Fix deny pattern regex for format/mkfs/diskpart to use
   (?:^|\s) instead of \b to avoid matching --format flags.
   Fix path extraction regex to use submatch to avoid matching flags
   like -rf as paths.

4. loop.go: Adjust forceCompression mid-point to avoid splitting
   tool-call/result message pairs, which causes API errors.
2026-03-01 08:44:15 +07:00

333 lines
8.3 KiB
Go

package tools
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/sipeed/picoclaw/pkg/config"
)
// ExecTool runs shell commands ("sh -c" on Unix, PowerShell on Windows) after
// screening them with regex deny/allow lists and an optional workspace
// restriction.
type ExecTool struct {
	// workingDir is the default directory commands run in; it is also the
	// base that a caller-supplied working_dir is validated against when
	// restrictToWorkspace is set.
	workingDir string
	// timeout bounds each command's run time; zero or negative means no limit.
	timeout time.Duration
	// denyPatterns and allowPatterns are matched against the trimmed,
	// lower-cased command: any deny match blocks it, and when allowPatterns
	// is non-empty the command must match at least one of them.
	denyPatterns, allowPatterns []*regexp.Regexp
	// restrictToWorkspace enables path-traversal and outside-path checks.
	restrictToWorkspace bool
}
// defaultDenyPatterns are the built-in deny rules. guardCommand matches them
// against the trimmed, lower-cased command, so they are written in lower case.
var defaultDenyPatterns = []*regexp.Regexp{
	regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
	regexp.MustCompile(`\bdel\s+/[fq]\b`),
	regexp.MustCompile(`\brmdir\s+/s\b`),
	// Disk-wiping commands. The command word must start the command or follow
	// whitespace or a shell separator (rather than any \b), so flags such as
	// --format and names such as clang-format are not matched. mkfs also
	// covers the mkfs.<fstype> front-ends (mkfs.ext4, mkfs.vfat, ...), and
	// mkfs/diskpart are caught even when given no arguments.
	regexp.MustCompile(`(?:^|[\s;&|(])(?:format\s|(?:mkfs(?:\.[a-z0-9]+)?|diskpart)(?:\s|$))`),
	regexp.MustCompile(`\bdd\s+if=`),
	regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
	regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
	regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
	regexp.MustCompile(`\$\([^)]+\)`),
	regexp.MustCompile(`\$\{[^}]+\}`),
	regexp.MustCompile("`[^`]+`"),
	regexp.MustCompile(`\|\s*sh\b`),
	regexp.MustCompile(`\|\s*bash\b`),
	regexp.MustCompile(`;\s*rm\s+-[rf]`),
	regexp.MustCompile(`&&\s*rm\s+-[rf]`),
	regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
	regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
	regexp.MustCompile(`<<\s*EOF`),
	regexp.MustCompile(`\$\(\s*cat\s+`),
	regexp.MustCompile(`\$\(\s*curl\s+`),
	regexp.MustCompile(`\$\(\s*wget\s+`),
	regexp.MustCompile(`\$\(\s*which\s+`),
	regexp.MustCompile(`\bsudo\b`),
	regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
	regexp.MustCompile(`\bchown\b`),
	regexp.MustCompile(`\bpkill\b`),
	regexp.MustCompile(`\bkillall\b`),
	regexp.MustCompile(`\bkill\s+-[9]\b`),
	regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
	regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
	regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
	regexp.MustCompile(`\bpip\s+install\s+--user\b`),
	regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
	regexp.MustCompile(`\byum\s+(install|remove)\b`),
	regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
	regexp.MustCompile(`\bdocker\s+run\b`),
	regexp.MustCompile(`\bdocker\s+exec\b`),
	regexp.MustCompile(`\bgit\s+push\b`),
	regexp.MustCompile(`\bgit\s+force\b`),
	regexp.MustCompile(`\bssh\b.*@`),
	regexp.MustCompile(`\beval\b`),
	regexp.MustCompile(`\bsource\s+.*\.sh\b`),
}
// NewExecTool creates an ExecTool for workingDir that uses the built-in deny
// patterns; it is NewExecToolWithConfig called with a nil config.
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
	tool, err := NewExecToolWithConfig(workingDir, restrict, nil)
	return tool, err
}
// NewExecToolWithConfig creates an ExecTool for workingDir, choosing its deny
// patterns from cfg:
//   - cfg == nil: the built-in defaultDenyPatterns.
//   - cfg.Tools.Exec.EnableDenyPatterns true: the defaults followed by every
//     entry of CustomDenyPatterns; an invalid custom pattern is returned as
//     an error.
//   - EnableDenyPatterns false: no deny patterns at all (a warning is printed).
//
// The returned tool has a 60-second timeout and no allowlist.
func NewExecToolWithConfig(workingDir string, restrict bool, cfg *config.Config) (*ExecTool, error) {
	denyPatterns := make([]*regexp.Regexp, 0)

	switch {
	case cfg == nil:
		denyPatterns = append(denyPatterns, defaultDenyPatterns...)
	case !cfg.Tools.Exec.EnableDenyPatterns:
		// Nothing is added, so no command is rejected by the deny stage.
		fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
	default:
		denyPatterns = append(denyPatterns, defaultDenyPatterns...)
		custom := cfg.Tools.Exec.CustomDenyPatterns
		if len(custom) > 0 {
			fmt.Printf("Using custom deny patterns: %v\n", custom)
		}
		for _, expr := range custom {
			re, err := regexp.Compile(expr)
			if err != nil {
				return nil, fmt.Errorf("invalid custom deny pattern %q: %w", expr, err)
			}
			denyPatterns = append(denyPatterns, re)
		}
	}

	return &ExecTool{
		workingDir:          workingDir,
		timeout:             60 * time.Second,
		denyPatterns:        denyPatterns,
		allowPatterns:       nil,
		restrictToWorkspace: restrict,
	}, nil
}
// Name returns the tool identifier, "exec".
func (*ExecTool) Name() string {
	const name = "exec"
	return name
}
// Description returns a one-line summary of what the tool does.
func (*ExecTool) Description() string {
	const summary = "Execute a shell command and return its output. Use with caution."
	return summary
}
// Parameters describes the tool's arguments in the object/properties/required
// shape of JSON Schema: a required string "command" and an optional string
// "working_dir".
func (*ExecTool) Parameters() map[string]any {
	commandProp := map[string]any{
		"type":        "string",
		"description": "The shell command to execute",
	}
	workingDirProp := map[string]any{
		"type":        "string",
		"description": "Optional working directory for the command",
	}

	return map[string]any{
		"type": "object",
		"properties": map[string]any{
			"command":     commandProp,
			"working_dir": workingDirProp,
		},
		"required": []string{"command"},
	}
}
// Execute runs args["command"] through the platform shell ("sh -c", or
// PowerShell with -NoProfile -NonInteractive on Windows) and returns its output.
//
// Arguments:
//   - "command" (string, required): the command line to run.
//   - "working_dir" (string, optional): directory to run in. When the tool is
//     restricted to its workspace and workingDir is set, it is validated
//     against workingDir first.
//
// The command is screened by guardCommand before it starts. The result holds
// stdout followed by any stderr under a "STDERR:" header; a non-nil wait error
// appends "Exit code: ..." and marks the result as an error. If the timeout
// elapses, the process tree is terminated and only a timeout message is
// returned. Output longer than 10000 bytes is truncated on a UTF-8 character
// boundary.
func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
	const maxOutputLen = 10000

	command, ok := args["command"].(string)
	if !ok {
		return ErrorResult("command is required")
	}

	cwd := t.workingDir
	if wd, ok := args["working_dir"].(string); ok && wd != "" {
		if t.restrictToWorkspace && t.workingDir != "" {
			resolvedWD, err := validatePath(wd, t.workingDir, true)
			if err != nil {
				return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
			}
			cwd = resolvedWD
		} else {
			cwd = wd
		}
	}
	if cwd == "" {
		// Fall back to the process working directory; cwd is used both by
		// guardCommand and as the command's Dir.
		if wd, err := os.Getwd(); err == nil {
			cwd = wd
		}
	}

	if guardError := t.guardCommand(command, cwd); guardError != "" {
		return ErrorResult(guardError)
	}

	// A non-positive timeout means no timeout.
	var cmdCtx context.Context
	var cancel context.CancelFunc
	if t.timeout > 0 {
		cmdCtx, cancel = context.WithTimeout(ctx, t.timeout)
	} else {
		cmdCtx, cancel = context.WithCancel(ctx)
	}
	defer cancel()

	var cmd *exec.Cmd
	if runtime.GOOS == "windows" {
		cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command)
	} else {
		cmd = exec.CommandContext(cmdCtx, "sh", "-c", command)
	}
	if cwd != "" {
		cmd.Dir = cwd
	}
	prepareCommandForTermination(cmd)

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Start(); err != nil {
		return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
	}

	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
	}()

	var err error
	select {
	case err = <-done:
	case <-cmdCtx.Done():
		// Best effort: ask the process tree to stop. If it has not exited
		// within 2 seconds, kill the direct child and wait for Wait to return.
		_ = terminateProcessTree(cmd)
		select {
		case err = <-done:
		case <-time.After(2 * time.Second):
			if cmd.Process != nil {
				_ = cmd.Process.Kill()
			}
			err = <-done
		}
	}

	output := stdout.String()
	if stderr.Len() > 0 {
		output += "\nSTDERR:\n" + stderr.String()
	}

	if err != nil {
		if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
			msg := fmt.Sprintf("Command timed out after %v", t.timeout)
			return &ToolResult{
				ForLLM:  msg,
				ForUser: msg,
				IsError: true,
			}
		}
		output += fmt.Sprintf("\nExit code: %v", err)
	}

	if output == "" {
		output = "(no output)"
	}
	output = truncateExecOutput(output, maxOutputLen)

	return &ToolResult{
		ForLLM:  output,
		ForUser: output,
		IsError: err != nil,
	}
}

// truncateExecOutput returns s unchanged when it fits in maxLen bytes
// (maxLen must be non-negative). Otherwise it keeps at most maxLen bytes,
// backing off to the start of a UTF-8 sequence so no multi-byte character is
// split (which would make the tool result invalid UTF-8), and appends a note
// with the number of bytes dropped.
func truncateExecOutput(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + fmt.Sprintf("\n... (truncated, %d more chars)", len(s)-cut)
}
// execAbsPathPattern finds absolute-path tokens in a command line for the
// workspace check. A token must start the command or follow whitespace, a
// quote, "=", a redirection operator or a shell separator. As a result,
// relative paths such as ./build or src/main.go are not mistaken for absolute
// paths, while "/etc/passwd", --file=/etc/passwd and >/etc/x are still caught.
// The token body stops at whitespace, quotes and shell metacharacters, so one
// token cannot swallow a following ";cmd /abs/path". Submatch 1 is the path.
var execAbsPathPattern = regexp.MustCompile(`(?:^|[\s"'=<>|;&(])([A-Za-z]:\\[^\s"'<>|;&()]*|/[a-zA-Z][^\s"'<>|;&()]*)`)

// execParentDirPattern matches a bare ".." token (e.g. "cd ..", "ls .."), which
// the "../" and "..\" substring checks do not catch.
var execParentDirPattern = regexp.MustCompile(`(?:^|[\s"'=<>|;&(])\.\.(?:$|[\s"'<>|;&)])`)

// guardCommand decides whether command may run in cwd. It returns "" when the
// command is allowed, or a human-readable reason when it is blocked.
//
// Checks, in order:
//  1. deny patterns, matched against the trimmed, lower-cased command;
//  2. the allowlist, when one is configured;
//  3. when restrictToWorkspace is set: path traversal ("../", "..\" or a bare
//     ".." token) and any absolute path that resolves outside cwd.
//
// The workspace check fails closed. The command is blocked if cwd or a
// referenced path cannot be resolved, or if a path cannot be expressed
// relative to cwd (for example, one on a different Windows volume).
func (t *ExecTool) guardCommand(command, cwd string) string {
	cmd := strings.TrimSpace(command)
	lower := strings.ToLower(cmd)

	for _, pattern := range t.denyPatterns {
		if pattern.MatchString(lower) {
			return "Command blocked by safety guard (dangerous pattern detected)"
		}
	}

	if len(t.allowPatterns) > 0 {
		allowed := false
		for _, pattern := range t.allowPatterns {
			if pattern.MatchString(lower) {
				allowed = true
				break
			}
		}
		if !allowed {
			return "Command blocked by safety guard (not in allowlist)"
		}
	}

	if !t.restrictToWorkspace {
		return ""
	}

	if strings.Contains(cmd, "..\\") || strings.Contains(cmd, "../") || execParentDirPattern.MatchString(cmd) {
		return "Command blocked by safety guard (path traversal detected)"
	}

	cwdPath, err := filepath.Abs(cwd)
	if err != nil {
		return "Command blocked by safety guard (cannot resolve working dir)"
	}

	for _, match := range execAbsPathPattern.FindAllStringSubmatch(cmd, -1) {
		p, err := filepath.Abs(match[1])
		if err != nil {
			return "Command blocked by safety guard (path outside working dir)"
		}
		rel, err := filepath.Rel(cwdPath, p)
		// Rel fails when p cannot be made relative to cwdPath; treat that as
		// outside. Compare whole path elements so a name such as "..cache"
		// inside the workspace is not mistaken for a parent-directory escape.
		if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			return "Command blocked by safety guard (path outside working dir)"
		}
	}

	return ""
}
// SetTimeout sets how long Execute lets a command run before terminating it.
// A zero or negative duration disables the timeout.
func (t *ExecTool) SetTimeout(d time.Duration) {
	t.timeout = d
}
// SetRestrictToWorkspace toggles workspace confinement. When enabled (and
// workingDir is set), Execute validates a caller-supplied working_dir against
// workingDir. guardCommand also rejects path traversal and absolute paths
// outside the command's working directory.
func (t *ExecTool) SetRestrictToWorkspace(enabled bool) {
	t.restrictToWorkspace = enabled
}
// SetAllowPatterns replaces the allowlist with the given regular expressions.
// While the allowlist is non-empty, guardCommand only lets through commands
// whose trimmed, lower-cased text matches at least one pattern. A nil or
// empty slice disables allowlisting.
//
// All patterns are compiled before anything is changed. If any pattern is
// invalid, an error naming it is returned and the previous allowlist is left
// intact. It is not replaced by a partial list, or by an empty list that
// would silently disable allowlisting.
func (t *ExecTool) SetAllowPatterns(patterns []string) error {
	compiled := make([]*regexp.Regexp, 0, len(patterns))
	for _, p := range patterns {
		re, err := regexp.Compile(p)
		if err != nil {
			return fmt.Errorf("invalid allow pattern %q: %w", p, err)
		}
		compiled = append(compiled, re)
	}
	t.allowPatterns = compiled
	return nil
}