mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
881999aceb
Replace unconditional WithTimeout usage with conditional context creation based on timeout configuration. Zero values now bypass timeout enforcement, using WithCancel for graceful cancellation while preserving existing timeout behavior for positive values. Simplifies CronTool initialization by removing unnecessary conditional timeout assignment.
233 lines
5.1 KiB
Go
233 lines
5.1 KiB
Go
package tools
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ExecTool runs shell commands for the agent, screened by regex deny/allow
// lists and an optional restriction to the working directory.
type ExecTool struct {
	// workingDir is the default directory commands run in; when empty,
	// Execute falls back to the process's current directory.
	workingDir string
	// timeout bounds each command's runtime; a non-positive value disables it.
	timeout time.Duration
	// denyPatterns block a command when any of them matches; allowPatterns,
	// when non-empty, require at least one match. Both are tested against the
	// trimmed, lower-cased command.
	denyPatterns, allowPatterns []*regexp.Regexp
	// restrictToWorkspace enables the path-based checks in guardCommand.
	restrictToWorkspace bool
}
|
|
|
|
func NewExecTool(workingDir string, restrict bool) *ExecTool {
|
|
denyPatterns := []*regexp.Regexp{
|
|
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
|
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
|
regexp.MustCompile(`\brmdir\s+/s\b`),
|
|
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
|
|
regexp.MustCompile(`\bdd\s+if=`),
|
|
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
|
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
|
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
|
}
|
|
|
|
return &ExecTool{
|
|
workingDir: workingDir,
|
|
timeout: 60 * time.Second,
|
|
denyPatterns: denyPatterns,
|
|
allowPatterns: nil,
|
|
restrictToWorkspace: restrict,
|
|
}
|
|
}
|
|
|
|
func (t *ExecTool) Name() string {
|
|
return "exec"
|
|
}
|
|
|
|
func (t *ExecTool) Description() string {
|
|
return "Execute a shell command and return its output. Use with caution."
|
|
}
|
|
|
|
func (t *ExecTool) Parameters() map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"command": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "The shell command to execute",
|
|
},
|
|
"working_dir": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "Optional working directory for the command",
|
|
},
|
|
},
|
|
"required": []string{"command"},
|
|
}
|
|
}
|
|
|
|
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
|
command, ok := args["command"].(string)
|
|
if !ok {
|
|
return ErrorResult("command is required")
|
|
}
|
|
|
|
cwd := t.workingDir
|
|
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
|
cwd = wd
|
|
}
|
|
|
|
if cwd == "" {
|
|
wd, err := os.Getwd()
|
|
if err == nil {
|
|
cwd = wd
|
|
}
|
|
}
|
|
|
|
if guardError := t.guardCommand(command, cwd); guardError != "" {
|
|
return ErrorResult(guardError)
|
|
}
|
|
|
|
// timeout == 0 means no timeout
|
|
var cmdCtx context.Context
|
|
var cancel context.CancelFunc
|
|
if t.timeout > 0 {
|
|
cmdCtx, cancel = context.WithTimeout(ctx, t.timeout)
|
|
} else {
|
|
cmdCtx, cancel = context.WithCancel(ctx)
|
|
}
|
|
defer cancel()
|
|
|
|
var cmd *exec.Cmd
|
|
if runtime.GOOS == "windows" {
|
|
cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command)
|
|
} else {
|
|
cmd = exec.CommandContext(cmdCtx, "sh", "-c", command)
|
|
}
|
|
if cwd != "" {
|
|
cmd.Dir = cwd
|
|
}
|
|
|
|
var stdout, stderr bytes.Buffer
|
|
cmd.Stdout = &stdout
|
|
cmd.Stderr = &stderr
|
|
|
|
err := cmd.Run()
|
|
output := stdout.String()
|
|
if stderr.Len() > 0 {
|
|
output += "\nSTDERR:\n" + stderr.String()
|
|
}
|
|
|
|
if err != nil {
|
|
if cmdCtx.Err() == context.DeadlineExceeded {
|
|
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
|
|
return &ToolResult{
|
|
ForLLM: msg,
|
|
ForUser: msg,
|
|
IsError: true,
|
|
}
|
|
}
|
|
output += fmt.Sprintf("\nExit code: %v", err)
|
|
}
|
|
|
|
if output == "" {
|
|
output = "(no output)"
|
|
}
|
|
|
|
maxLen := 10000
|
|
if len(output) > maxLen {
|
|
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
|
|
}
|
|
|
|
if err != nil {
|
|
return &ToolResult{
|
|
ForLLM: output,
|
|
ForUser: output,
|
|
IsError: true,
|
|
}
|
|
}
|
|
|
|
return &ToolResult{
|
|
ForLLM: output,
|
|
ForUser: output,
|
|
IsError: false,
|
|
}
|
|
}
|
|
|
|
func (t *ExecTool) guardCommand(command, cwd string) string {
|
|
cmd := strings.TrimSpace(command)
|
|
lower := strings.ToLower(cmd)
|
|
|
|
for _, pattern := range t.denyPatterns {
|
|
if pattern.MatchString(lower) {
|
|
return "Command blocked by safety guard (dangerous pattern detected)"
|
|
}
|
|
}
|
|
|
|
if len(t.allowPatterns) > 0 {
|
|
allowed := false
|
|
for _, pattern := range t.allowPatterns {
|
|
if pattern.MatchString(lower) {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
if !allowed {
|
|
return "Command blocked by safety guard (not in allowlist)"
|
|
}
|
|
}
|
|
|
|
if t.restrictToWorkspace {
|
|
if strings.Contains(cmd, "..\\") || strings.Contains(cmd, "../") {
|
|
return "Command blocked by safety guard (path traversal detected)"
|
|
}
|
|
|
|
cwdPath, err := filepath.Abs(cwd)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
pathPattern := regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
|
|
matches := pathPattern.FindAllString(cmd, -1)
|
|
|
|
for _, raw := range matches {
|
|
p, err := filepath.Abs(raw)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
rel, err := filepath.Rel(cwdPath, p)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if strings.HasPrefix(rel, "..") {
|
|
return "Command blocked by safety guard (path outside working dir)"
|
|
}
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (t *ExecTool) SetTimeout(timeout time.Duration) {
|
|
t.timeout = timeout
|
|
}
|
|
|
|
func (t *ExecTool) SetRestrictToWorkspace(restrict bool) {
|
|
t.restrictToWorkspace = restrict
|
|
}
|
|
|
|
func (t *ExecTool) SetAllowPatterns(patterns []string) error {
|
|
t.allowPatterns = make([]*regexp.Regexp, 0, len(patterns))
|
|
for _, p := range patterns {
|
|
re, err := regexp.Compile(p)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid allow pattern %q: %w", p, err)
|
|
}
|
|
t.allowPatterns = append(t.allowPatterns, re)
|
|
}
|
|
return nil
|
|
}
|