diff --git a/cmd/picoclaw/internal/onboard/command.go b/cmd/picoclaw/internal/onboard/command.go index bf8f4104f..3f0ff0d8d 100644 --- a/cmd/picoclaw/internal/onboard/command.go +++ b/cmd/picoclaw/internal/onboard/command.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" ) -//go:generate go run ../../../../scripts/copydir.go ../../../../workspace ./workspace +//go:generate go run ../../../../scripts/copydir.go "${DOLLAR}{codespace}/workspace" ./workspace //go:embed workspace var embeddedFiles embed.FS diff --git a/scripts/copydir.go b/scripts/copydir.go index 74eff6c72..6e2777612 100644 --- a/scripts/copydir.go +++ b/scripts/copydir.go @@ -5,6 +5,8 @@ import ( "io" "os" "path/filepath" + "runtime" + "strings" ) func main() { @@ -13,8 +15,36 @@ func main() { os.Exit(2) } - src := os.Args[1] - dst := os.Args[2] + repoRoot, err := findRepoRoot() + if err != nil { + fmt.Fprintf(os.Stderr, "locate repo root: %v\n", err) + os.Exit(1) + } + + src, err := normalizePathArg(os.Args[1], repoRoot) + if err != nil { + fmt.Fprintf(os.Stderr, "resolve src path: %v\n", err) + os.Exit(1) + } + + dst, err := normalizePathArg(os.Args[2], repoRoot) + if err != nil { + fmt.Fprintf(os.Stderr, "resolve dst path: %v\n", err) + os.Exit(1) + } + + if err := ensurePathWithinRepo(repoRoot, src); err != nil { + fmt.Fprintf(os.Stderr, "invalid src path: %v\n", err) + os.Exit(1) + } + if err := ensurePathWithinRepo(repoRoot, dst); err != nil { + fmt.Fprintf(os.Stderr, "invalid dst path: %v\n", err) + os.Exit(1) + } + if samePath(repoRoot, dst) { + fmt.Fprintln(os.Stderr, "invalid dst path: destination cannot be repo root") + os.Exit(1) + } if err := os.RemoveAll(dst); err != nil { fmt.Fprintf(os.Stderr, "remove %s: %v\n", dst, err) @@ -27,6 +57,78 @@ func main() { } } +func findRepoRoot() (string, error) { + _, file, _, ok := runtime.Caller(0) + if !ok { + return "", fmt.Errorf("unable to locate copydir.go source path") + } + + scriptDir := filepath.Dir(file) + candidate := filepath.Clean(filepath.Join(scriptDir, "..")) + if err := validateRepoRoot(candidate); err == nil { + return candidate, nil + } + + wd, err := os.Getwd() + if err != nil { + return "", err + } + + cur, err := filepath.Abs(wd) + if err != nil { + return "", err + } + + for { + if err := validateRepoRoot(cur); err == nil { + return filepath.Clean(cur), nil + } + parent := filepath.Dir(cur) + if parent == cur { + return "", fmt.Errorf("could not find repository root from %s", wd) + } + cur = parent + } +} + +func validateRepoRoot(root string) error { + anchors := []string{ + filepath.Join(root, "go.sum"), + filepath.Join(root, "LICENSE"), + filepath.Join(root, ".github"), + } + for _, anchor := range anchors { + if _, err := os.Stat(anchor); err != nil { + return fmt.Errorf("missing repo anchor %s: %w", anchor, err) + } + } + return nil +} + +func normalizePathArg(arg, repoRoot string) (string, error) { + resolved := strings.ReplaceAll(arg, "${codespace}", repoRoot) + abs, err := filepath.Abs(resolved) + if err != nil { + return "", err + } + return filepath.Clean(abs), nil +} + +func ensurePathWithinRepo(repoRoot, path string) error { + rel, err := filepath.Rel(repoRoot, path) + if err != nil { + return err + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return fmt.Errorf("path %s is outside repository root %s", path, repoRoot) + } + return nil +} + +func samePath(a, b string) bool { + return filepath.Clean(a) == filepath.Clean(b) +} + func copyTree(src, dst string) error { info, err := os.Stat(src) if err != nil { diff --git a/web/backend/api/exec_nonwindows.go b/web/backend/api/exec_nonwindows.go new file mode 100644 index 000000000..0dc3c0e94 --- /dev/null +++ b/web/backend/api/exec_nonwindows.go @@ -0,0 +1,11 @@ +//go:build !windows + +package api + +import "os/exec" + +func launcherExecCommand(name string, args ...string) *exec.Cmd { + return exec.Command(name, args...) +} + +func applyLauncherProcAttrs(_ *exec.Cmd) {} diff --git a/web/backend/api/exec_windows.go b/web/backend/api/exec_windows.go new file mode 100644 index 000000000..86d3193a0 --- /dev/null +++ b/web/backend/api/exec_windows.go @@ -0,0 +1,24 @@ +//go:build windows + +package api + +import ( + "os/exec" + "syscall" +) + +func launcherExecCommand(name string, args ...string) *exec.Cmd { + cmd := exec.Command(name, args...) + applyLauncherProcAttrs(cmd) + return cmd +} + +func applyLauncherProcAttrs(cmd *exec.Cmd) { + if cmd == nil { + return + } + if cmd.SysProcAttr == nil { + cmd.SysProcAttr = &syscall.SysProcAttr{} + } + cmd.SysProcAttr.HideWindow = true +} diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 201000ff3..606c8351d 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -164,7 +164,7 @@ func isLikelyGatewayProcess(pid int) (bool, bool) { `$p=Get-CimInstance Win32_Process -Filter "ProcessId = %d"; if ($null -eq $p) { "" } else { $p.CommandLine }`, pid, ) - out, err := exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", psCmd).Output() + out, err := launcherExecCommand("powershell", "-NoProfile", "-NonInteractive", "-Command", psCmd).Output() if err == nil { cmdline := strings.TrimSpace(string(out)) if cmdline != "" { @@ -173,7 +173,7 @@ func isLikelyGatewayProcess(pid int) (bool, bool) { } // Fallback: determine only whether the process still exists. - out, err = exec.Command("tasklist", "/FI", "PID eq "+strconv.Itoa(pid), "/FO", "CSV", "/NH").Output() + out, err = launcherExecCommand("tasklist", "/FI", "PID eq "+strconv.Itoa(pid), "/FO", "CSV", "/NH").Output() if err != nil { return false, false } @@ -187,7 +187,7 @@ func isLikelyGatewayProcess(pid int) (bool, bool) { if strings.Contains(line, "\"picoclaw.exe\"") { return true, true } - return false, false + return false, true } if strings.Contains(line, "no tasks are running") { return false, true @@ -195,7 +195,7 @@ func isLikelyGatewayProcess(pid int) (bool, bool) { return false, true } - out, err := exec.Command("ps", "-o", "command=", "-p", strconv.Itoa(pid)).Output() + out, err := launcherExecCommand("ps", "-o", "command=", "-p", strconv.Itoa(pid)).Output() if err != nil { return false, false } @@ -706,6 +706,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath)) cmd = gatewayExecCommand(execPath, h.gatewayCommandArgs()...) + applyLauncherProcAttrs(cmd) cmd.Env = os.Environ() // Forward the launcher's config path via the environment variable that // GetConfigPath() already reads, so the gateway sub-process uses the same