mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(exec): terminate process tree on timeout
This commit is contained in:
+28
-2
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -109,18 +110,43 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
|
||||
cmd.Dir = cwd
|
||||
}
|
||||
|
||||
prepareCommandForTermination(cmd)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err := cmd.Start(); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cmd.Wait()
|
||||
}()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-cmdCtx.Done():
|
||||
_ = terminateProcessTree(cmd)
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
err = <-done
|
||||
}
|
||||
}
|
||||
|
||||
output := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
output += "\nSTDERR:\n" + stderr.String()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
|
||||
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
//go:build !windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func prepareCommandForTermination(cmd *exec.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
if pid <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Kill the entire process group spawned by the shell command.
|
||||
_ = syscall.Kill(-pid, syscall.SIGKILL)
|
||||
// Fallback kill on the shell process itself.
|
||||
_ = cmd.Process.Kill()
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
//go:build windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func prepareCommandForTermination(cmd *exec.Cmd) {
|
||||
// no-op on Windows
|
||||
}
|
||||
|
||||
func terminateProcessTree(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
if pid <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run()
|
||||
_ = cmd.Process.Kill()
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
//go:build !windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func processExists(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
err := syscall.Kill(pid, 0)
|
||||
return err == nil || err == syscall.EPERM
|
||||
}
|
||||
|
||||
func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
|
||||
tool := NewExecTool(t.TempDir(), false)
|
||||
tool.SetTimeout(500 * time.Millisecond)
|
||||
|
||||
args := map[string]interface{}{
|
||||
// Spawn a child process that would outlive the shell unless process-group kill is used.
|
||||
"command": "sleep 60 & echo $! > child.pid; wait",
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), args)
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected timeout error, got success: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "timed out") {
|
||||
t.Fatalf("expected timeout message, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
childPIDPath := filepath.Join(tool.workingDir, "child.pid")
|
||||
data, err := os.ReadFile(childPIDPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read child pid file: %v", err)
|
||||
}
|
||||
|
||||
childPID, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse child pid: %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if !processExists(childPID) {
|
||||
return
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("child process %d is still running after timeout", childPID)
|
||||
}
|
||||
Reference in New Issue
Block a user