mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge main
This commit is contained in:
@@ -236,9 +236,8 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
t.Fatal("exec tool not registered")
|
||||
}
|
||||
execResult := execTool.Execute(context.Background(), map[string]any{
|
||||
"action": "run",
|
||||
"command": "cat " + filepath.Base(mediaPath),
|
||||
"cwd": mediaDir,
|
||||
"command": "cat " + filepath.Base(mediaPath),
|
||||
"working_dir": mediaDir,
|
||||
})
|
||||
if execResult.IsError {
|
||||
t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
|
||||
|
||||
@@ -253,6 +253,7 @@ type AgentDefaults struct {
|
||||
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
|
||||
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
|
||||
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
|
||||
LogLevel string `json:"log_level,omitempty" env:"PICOCLAW_LOG_LEVEL"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -470,6 +470,13 @@ func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_LogLevel(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Agents.Defaults.LogLevel != "fatal" {
|
||||
t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Agents.Defaults.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -1057,3 +1064,38 @@ func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigParsesLogLevel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{"agents":{"defaults":{"log_level":"debug"}}}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
if cfg.Agents.Defaults.LogLevel != "debug" {
|
||||
t.Errorf("LogLevel = %q, want \"debug\"", cfg.Agents.Defaults.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigLogLevelEmpty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
// When config omits log_level, the DefaultConfig value ("fatal") is preserved.
|
||||
if cfg.Agents.Defaults.LogLevel != "fatal" {
|
||||
t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Agents.Defaults.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
LogLevel: "fatal",
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
|
||||
@@ -79,16 +79,18 @@ func (p *startupBlockedProvider) GetDefaultModel() string {
|
||||
|
||||
// Run starts the gateway runtime using the configuration loaded from configPath.
|
||||
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %w", err)
|
||||
}
|
||||
|
||||
logger.SetLevelFromString(cfg.Agents.Defaults.LogLevel)
|
||||
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
const (
|
||||
minIntervalMinutes = 5
|
||||
defaultIntervalMinutes = 30
|
||||
userTasksMarker = "Add your heartbeat tasks below this line:"
|
||||
)
|
||||
|
||||
// HeartbeatHandler is the function type for handling heartbeat.
|
||||
@@ -232,7 +233,7 @@ func (hs *HeartbeatService) buildPrompt() string {
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if len(content) == 0 {
|
||||
if !heartbeatHasUserTasks(content) {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -284,6 +285,32 @@ Add your heartbeat tasks below this line:
|
||||
}
|
||||
}
|
||||
|
||||
func heartbeatHasUserTasks(content string) bool {
|
||||
trimmed := strings.TrimSpace(content)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
markerIdx := strings.Index(content, userTasksMarker)
|
||||
if markerIdx < 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
tasksSection := content[markerIdx+len(userTasksMarker):]
|
||||
for _, line := range strings.Split(tasksSection, "\n") {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
if trimmedLine == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimmedLine, "#") {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// sendResponse sends the heartbeat response to the last channel
|
||||
func (hs *HeartbeatService) sendResponse(response string) {
|
||||
hs.mu.RLock()
|
||||
|
||||
@@ -3,6 +3,7 @@ package heartbeat
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -203,3 +204,47 @@ func TestHeartbeatFilePath(t *testing.T) {
|
||||
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_DefaultTemplateStaysIdle(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.createDefaultHeartbeatTemplate()
|
||||
|
||||
if prompt := hs.buildPrompt(); prompt != "" {
|
||||
t.Fatalf("buildPrompt() = %q, want empty prompt for untouched default template", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_UserTasksAfterMarkerProducePrompt(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.createDefaultHeartbeatTemplate()
|
||||
|
||||
path := filepath.Join(tmpDir, "HEARTBEAT.md")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read HEARTBEAT.md: %v", err)
|
||||
}
|
||||
updated := string(data) + "\n- Check unread Feishu messages\n"
|
||||
if err := os.WriteFile(path, []byte(updated), 0o644); err != nil {
|
||||
t.Fatalf("Failed to update HEARTBEAT.md: %v", err)
|
||||
}
|
||||
|
||||
prompt := hs.buildPrompt()
|
||||
if prompt == "" {
|
||||
t.Fatal("buildPrompt() = empty, want non-empty prompt when user tasks are present")
|
||||
}
|
||||
if !strings.Contains(prompt, "Check unread Feishu messages") {
|
||||
t.Fatalf("prompt = %q, want user task content", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,13 +94,18 @@ func MatchAllowed(sender bus.SenderInfo, allowed string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// isNumeric returns true if s consists entirely of digits.
|
||||
// isNumeric returns true if s consists entirely of digits, allowing for an optional leading minus sign
|
||||
// (required for Telegram group/channel IDs like -1001234567890).
|
||||
func isNumeric(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if r < '0' || r > '9' {
|
||||
start := 0
|
||||
if s[0] == '-' && len(s) > 1 {
|
||||
start = 1
|
||||
}
|
||||
for i := start; i < len(s); i++ {
|
||||
if s[i] < '0' || s[i] > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +97,15 @@ func TestMatchAllowed(t *testing.T) {
|
||||
allowed: "654321",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negative numeric ID matches PlatformID",
|
||||
sender: bus.SenderInfo{
|
||||
Platform: "telegram",
|
||||
PlatformID: "-1001234567890",
|
||||
},
|
||||
allowed: "-1001234567890",
|
||||
want: true,
|
||||
},
|
||||
// Username matching
|
||||
{
|
||||
name: "@username matches Username",
|
||||
@@ -238,6 +247,9 @@ func TestIsNumeric(t *testing.T) {
|
||||
{"abc", false},
|
||||
{"12a34", false},
|
||||
{"telegram", false},
|
||||
{"-1001234567890", true},
|
||||
{"-", false},
|
||||
{"-12a34", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -106,6 +106,36 @@ func GetLevel() LogLevel {
|
||||
return currentLevel
|
||||
}
|
||||
|
||||
// ParseLevel converts a case-insensitive level name to a LogLevel.
|
||||
// Returns the level and true if valid, or (INFO, false) if unrecognized.
|
||||
func ParseLevel(s string) (LogLevel, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "debug":
|
||||
return DEBUG, true
|
||||
case "info":
|
||||
return INFO, true
|
||||
case "warn", "warning":
|
||||
return WARN, true
|
||||
case "error":
|
||||
return ERROR, true
|
||||
case "fatal":
|
||||
return FATAL, true
|
||||
default:
|
||||
return INFO, false
|
||||
}
|
||||
}
|
||||
|
||||
// SetLevelFromString sets the log level from a string value.
|
||||
// If the string is empty or not a recognized level name, the current level is kept.
|
||||
func SetLevelFromString(s string) {
|
||||
if s == "" {
|
||||
return
|
||||
}
|
||||
if level, ok := ParseLevel(s); ok {
|
||||
SetLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
func EnableFileLogging(filePath string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
@@ -252,3 +252,88 @@ func TestFormatFieldValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLevelIsInfo(t *testing.T) {
|
||||
// The package-level default (before any SetLevel call) should be INFO.
|
||||
// Because earlier tests may have changed it, we just verify the constant is wired correctly.
|
||||
if logLevelNames[INFO] != "INFO" {
|
||||
t.Errorf("INFO constant mapped to %q, want \"INFO\"", logLevelNames[INFO])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLevelValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want LogLevel
|
||||
}{
|
||||
{"debug", DEBUG},
|
||||
{"DEBUG", DEBUG},
|
||||
{"Debug", DEBUG},
|
||||
{"info", INFO},
|
||||
{"INFO", INFO},
|
||||
{"warn", WARN},
|
||||
{"WARN", WARN},
|
||||
{"warning", WARN},
|
||||
{"WARNING", WARN},
|
||||
{"error", ERROR},
|
||||
{"ERROR", ERROR},
|
||||
{"fatal", FATAL},
|
||||
{"FATAL", FATAL},
|
||||
{" info ", INFO},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got, ok := ParseLevel(tt.input)
|
||||
if !ok {
|
||||
t.Fatalf("ParseLevel(%q) returned ok=false, want true", tt.input)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseLevel(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLevelInvalid(t *testing.T) {
|
||||
tests := []string{"", "garbage", "verbose", "trace", "critical"}
|
||||
|
||||
for _, input := range tests {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
_, ok := ParseLevel(input)
|
||||
if ok {
|
||||
t.Errorf("ParseLevel(%q) returned ok=true, want false", input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLevelFromString(t *testing.T) {
|
||||
initialLevel := GetLevel()
|
||||
defer SetLevel(initialLevel)
|
||||
|
||||
// Valid string changes the level
|
||||
SetLevel(INFO)
|
||||
SetLevelFromString("error")
|
||||
if got := GetLevel(); got != ERROR {
|
||||
t.Errorf("after SetLevelFromString(\"error\"): GetLevel() = %v, want ERROR", got)
|
||||
}
|
||||
|
||||
// Empty string is a no-op
|
||||
SetLevelFromString("")
|
||||
if got := GetLevel(); got != ERROR {
|
||||
t.Errorf("after SetLevelFromString(\"\"): GetLevel() = %v, want ERROR (unchanged)", got)
|
||||
}
|
||||
|
||||
// Invalid string is a no-op
|
||||
SetLevelFromString("garbage")
|
||||
if got := GetLevel(); got != ERROR {
|
||||
t.Errorf("after SetLevelFromString(\"garbage\"): GetLevel() = %v, want ERROR (unchanged)", got)
|
||||
}
|
||||
|
||||
// Case-insensitive
|
||||
SetLevelFromString("FATAL")
|
||||
if got := GetLevel(); got != FATAL {
|
||||
t.Errorf("after SetLevelFromString(\"FATAL\"): GetLevel() = %v, want FATAL", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,252 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const maxOutputBufferSize = 100 * 1024 * 1024 // 100MB
|
||||
|
||||
const outputTruncateMarker = "\n... [output truncated, exceeded 100MB]\n"
|
||||
|
||||
// PtyKeyMode represents arrow key encoding mode for PTY sessions.
|
||||
// Programs send smkx/rmkx sequences to switch between CSI and SS3 modes.
|
||||
type PtyKeyMode uint8
|
||||
|
||||
const (
|
||||
PtyKeyModeCSI PtyKeyMode = iota // triggered by rmkx (\x1b[?1l)
|
||||
PtyKeyModeSS3 // triggered by smkx (\x1b[?1h)
|
||||
)
|
||||
|
||||
const PtyKeyModeNotFound PtyKeyMode = 255
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
ErrSessionDone = errors.New("session already completed")
|
||||
ErrPTYNotSupported = errors.New("PTY is not supported on this platform")
|
||||
ErrNoStdin = errors.New("no stdin available")
|
||||
)
|
||||
|
||||
type ProcessSession struct {
|
||||
mu sync.Mutex
|
||||
ID string
|
||||
PID int
|
||||
Command string
|
||||
PTY bool
|
||||
Background bool
|
||||
StartTime int64
|
||||
ExitCode int
|
||||
Status string
|
||||
stdinWriter io.Writer
|
||||
stdoutPipe io.Reader
|
||||
outputBuffer *bytes.Buffer
|
||||
outputTruncated bool
|
||||
ptyMaster *os.File
|
||||
|
||||
// ptyKeyMode tracks arrow key encoding mode (CSI vs SS3)
|
||||
ptyKeyMode PtyKeyMode
|
||||
}
|
||||
|
||||
func (s *ProcessSession) IsDone() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Status == "done" || s.Status == "exited"
|
||||
}
|
||||
|
||||
func (s *ProcessSession) GetPtyKeyMode() PtyKeyMode {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ptyKeyMode
|
||||
}
|
||||
|
||||
func (s *ProcessSession) SetPtyKeyMode(mode PtyKeyMode) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ptyKeyMode = mode
|
||||
}
|
||||
|
||||
func (s *ProcessSession) GetStatus() string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Status
|
||||
}
|
||||
|
||||
func (s *ProcessSession) SetStatus(status string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Status = status
|
||||
}
|
||||
|
||||
func (s *ProcessSession) GetExitCode() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.ExitCode
|
||||
}
|
||||
|
||||
func (s *ProcessSession) SetExitCode(code int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ExitCode = code
|
||||
}
|
||||
|
||||
func (s *ProcessSession) killProcess() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.Status != "running" {
|
||||
return ErrSessionDone
|
||||
}
|
||||
|
||||
pid := s.PID
|
||||
if pid <= 0 {
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
|
||||
if err := killProcessGroup(pid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.Status = "done"
|
||||
s.ExitCode = -1
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProcessSession) Kill() error {
|
||||
return s.killProcess()
|
||||
}
|
||||
|
||||
func (s *ProcessSession) Write(data string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.Status != "running" {
|
||||
return ErrSessionDone
|
||||
}
|
||||
|
||||
var writer io.Writer
|
||||
if s.PTY && s.ptyMaster != nil {
|
||||
writer = s.ptyMaster
|
||||
} else if s.stdinWriter != nil {
|
||||
writer = s.stdinWriter
|
||||
} else {
|
||||
return ErrNoStdin
|
||||
}
|
||||
|
||||
_, err := writer.Write([]byte(data))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ProcessSession) Read() string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.outputBuffer.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
data := s.outputBuffer.String()
|
||||
s.outputBuffer.Reset()
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *ProcessSession) ToSessionInfo() SessionInfo {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return SessionInfo{
|
||||
ID: s.ID,
|
||||
Command: s.Command,
|
||||
Status: s.Status,
|
||||
PID: s.PID,
|
||||
StartedAt: s.StartTime,
|
||||
}
|
||||
}
|
||||
|
||||
type SessionManager struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*ProcessSession
|
||||
}
|
||||
|
||||
func NewSessionManager() *SessionManager {
|
||||
sm := &SessionManager{
|
||||
sessions: make(map[string]*ProcessSession),
|
||||
}
|
||||
|
||||
// Start cleaner goroutine - runs every 5 minutes, cleans up sessions done for >30 minutes
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
sm.cleanupOldSessions()
|
||||
}
|
||||
}()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// cleanupOldSessions removes sessions that are done and older than 30 minutes
|
||||
func (sm *SessionManager) cleanupOldSessions() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-30 * time.Minute)
|
||||
for id, session := range sm.sessions {
|
||||
if session.IsDone() && session.StartTime < cutoff.Unix() {
|
||||
delete(sm.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Add(session *ProcessSession) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.sessions[session.ID] = session
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Get(sessionID string) (*ProcessSession, error) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
session, ok := sm.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Remove(sessionID string) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
delete(sm.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (sm *SessionManager) List() []SessionInfo {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
result := make([]SessionInfo, 0, len(sm.sessions))
|
||||
for _, session := range sm.sessions {
|
||||
result = append(result, session.ToSessionInfo())
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
return uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
type SessionInfo struct {
|
||||
ID string `json:"id"`
|
||||
Command string `json:"command"`
|
||||
Status string `json:"status"`
|
||||
PID int `json:"pid"`
|
||||
StartedAt int64 `json:"startedAt"`
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func killProcessGroup(pid int) error {
|
||||
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
|
||||
_ = syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func killProcessGroup(pid int) error {
|
||||
_ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run()
|
||||
return nil
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSessionManager_AddGet(t *testing.T) {
|
||||
sm := NewSessionManager()
|
||||
session := &ProcessSession{
|
||||
ID: "test-1",
|
||||
Command: "echo hello",
|
||||
Status: "running",
|
||||
StartTime: 1000,
|
||||
}
|
||||
|
||||
sm.Add(session)
|
||||
|
||||
got, err := sm.Get("test-1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test-1", got.ID)
|
||||
}
|
||||
|
||||
func TestSessionManager_Remove(t *testing.T) {
|
||||
sm := NewSessionManager()
|
||||
session := &ProcessSession{
|
||||
ID: "test-1",
|
||||
Command: "echo hello",
|
||||
Status: "running",
|
||||
StartTime: 1000,
|
||||
}
|
||||
sm.Add(session)
|
||||
sm.Remove("test-1")
|
||||
|
||||
_, err := sm.Get("test-1")
|
||||
require.ErrorIs(t, err, ErrSessionNotFound)
|
||||
}
|
||||
|
||||
func TestSessionManager_List(t *testing.T) {
|
||||
sm := NewSessionManager()
|
||||
sm.Add(&ProcessSession{
|
||||
ID: "test-1",
|
||||
Command: "echo hello",
|
||||
Status: "running",
|
||||
StartTime: 1000,
|
||||
})
|
||||
sm.Add(&ProcessSession{
|
||||
ID: "test-2",
|
||||
Command: "echo world",
|
||||
Status: "running",
|
||||
StartTime: 1001,
|
||||
})
|
||||
sm.Add(&ProcessSession{
|
||||
ID: "test-3",
|
||||
Command: "echo done",
|
||||
Status: "done",
|
||||
StartTime: 1002,
|
||||
})
|
||||
|
||||
sessions := sm.List()
|
||||
require.Len(t, sessions, 3)
|
||||
|
||||
ids := make(map[string]bool)
|
||||
for _, s := range sessions {
|
||||
ids[s.ID] = true
|
||||
}
|
||||
require.True(t, ids["test-1"])
|
||||
require.True(t, ids["test-2"])
|
||||
require.True(t, ids["test-3"])
|
||||
}
|
||||
|
||||
func TestProcessSession_IsDone(t *testing.T) {
|
||||
session := &ProcessSession{Status: "running"}
|
||||
require.False(t, session.IsDone())
|
||||
|
||||
session.Status = "done"
|
||||
require.True(t, session.IsDone())
|
||||
|
||||
session.Status = "exited"
|
||||
require.True(t, session.IsDone())
|
||||
}
|
||||
|
||||
func TestProcessSession_ToSessionInfo(t *testing.T) {
|
||||
session := &ProcessSession{
|
||||
ID: "test-1",
|
||||
PID: 12345,
|
||||
Command: "echo hello",
|
||||
Status: "running",
|
||||
StartTime: 1000,
|
||||
}
|
||||
|
||||
info := session.ToSessionInfo()
|
||||
require.Equal(t, "test-1", info.ID)
|
||||
require.Equal(t, "echo hello", info.Command)
|
||||
require.Equal(t, "running", info.Status)
|
||||
require.Equal(t, 12345, info.PID)
|
||||
require.Equal(t, int64(1000), info.StartedAt)
|
||||
}
|
||||
+12
-730
@@ -3,37 +3,20 @@ package tools
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
)
|
||||
|
||||
var (
|
||||
globalSessionManager = NewSessionManager()
|
||||
sessionManagerMu sync.RWMutex
|
||||
)
|
||||
|
||||
func getSessionManager() *SessionManager {
|
||||
sessionManagerMu.RLock()
|
||||
defer sessionManagerMu.RUnlock()
|
||||
return globalSessionManager
|
||||
}
|
||||
|
||||
type ExecTool struct {
|
||||
workingDir string
|
||||
timeout time.Duration
|
||||
@@ -43,7 +26,6 @@ type ExecTool struct {
|
||||
allowedPathPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
allowRemote bool
|
||||
sessionManager *SessionManager
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -163,7 +145,7 @@ func NewExecToolWithConfig(
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
|
||||
var timeout time.Duration
|
||||
timeout := 60 * time.Second
|
||||
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
|
||||
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
|
||||
}
|
||||
@@ -177,7 +159,6 @@ func NewExecToolWithConfig(
|
||||
allowedPathPatterns: allowedPathPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
allowRemote: allowRemote,
|
||||
sessionManager: getSessionManager(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -186,146 +167,27 @@ func (t *ExecTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *ExecTool) Description() string {
|
||||
return `Execute shell commands. Use background=true for long-running commands (returns sessionId). Use pty=true for interactive commands (can combine with background=true). Use poll/read/write/send-keys/kill with sessionId to manage background sessions. Sessions auto-cleanup 30 minutes after process exits; use kill to terminate early. Output buffer limit: 100MB.`
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
}
|
||||
|
||||
func (t *ExecTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"oneOf": []map[string]any{
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{"const": "run", "description": "Execute a shell command"},
|
||||
"command": map[string]any{"type": "string", "description": "Shell command to execute"},
|
||||
"background": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Run in background immediately",
|
||||
},
|
||||
"pty": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Run in a pseudo-terminal (PTY) when available",
|
||||
},
|
||||
"cwd": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Working directory for the command",
|
||||
},
|
||||
"timeout": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (default: 0 = no timeout, kills process on expiry)",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "command"},
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The shell command to execute",
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{"const": "list", "description": "List all active sessions"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"const": "poll",
|
||||
"description": "Check session status. Returns: {sessionId, status: running|done, exitCode}. exitCode only meaningful when status=done",
|
||||
},
|
||||
"sessionId": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Session ID returned from background command",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "sessionId"},
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"const": "read",
|
||||
"description": "Read output from session. Returns: {sessionId, output, status: running|done}",
|
||||
},
|
||||
"sessionId": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Session ID returned from background command",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "sessionId"},
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"const": "write",
|
||||
"description": "Send input to session stdin (only when status=running)",
|
||||
},
|
||||
"sessionId": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Session ID returned from background command",
|
||||
},
|
||||
"data": map[string]any{"type": "string", "description": "Data to write to session stdin."},
|
||||
},
|
||||
"required": []string{"action", "sessionId", "data"},
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{"const": "kill", "description": "Terminate session"},
|
||||
"sessionId": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Session ID returned from background command",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "sessionId"},
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"const": "send-keys",
|
||||
"description": "Send special keys to PTY session. Keys: down/up/left/right/enter/escape/tab/backspace/ctrl-c/ctrl-d/ctrl-z. Multiple keys separated by comma",
|
||||
},
|
||||
"sessionId": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Session ID returned from background command",
|
||||
},
|
||||
"keys": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Comma-separated key names (optional spaces around comma). Valid keys: up, down, left, right, enter, tab, escape, backspace, ctrl-c, ctrl-d, home, end, pageup, pagedown, f1-f12.",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "sessionId", "keys"},
|
||||
"working_dir": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional working directory for the command",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
action, _ := args["action"].(string)
|
||||
if action == "" {
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "run":
|
||||
return t.executeRun(ctx, args)
|
||||
case "list":
|
||||
return t.executeList()
|
||||
case "poll":
|
||||
return t.executePoll(args)
|
||||
case "read":
|
||||
return t.executeRead(args)
|
||||
case "write":
|
||||
return t.executeWrite(args)
|
||||
case "kill":
|
||||
return t.executeKill(args)
|
||||
case "send-keys":
|
||||
return t.executeSendKeys(args)
|
||||
default:
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s", action))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolResult {
|
||||
command, ok := args["command"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("command is required")
|
||||
@@ -344,26 +206,8 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
|
||||
}
|
||||
}
|
||||
|
||||
getBoolArg := func(key string) bool {
|
||||
switch v := args[key].(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
return v == "true"
|
||||
}
|
||||
return false
|
||||
}
|
||||
isPty := getBoolArg("pty")
|
||||
isBackground := getBoolArg("background")
|
||||
|
||||
if isPty {
|
||||
if runtime.GOOS == "windows" {
|
||||
return ErrorResult("PTY is not supported on Windows. Use background=true without pty.")
|
||||
}
|
||||
}
|
||||
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["cwd"].(string); ok && wd != "" {
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
@@ -409,14 +253,6 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
|
||||
}
|
||||
}
|
||||
|
||||
if isBackground {
|
||||
return t.runBackground(ctx, command, cwd, isPty)
|
||||
}
|
||||
|
||||
return t.runSync(ctx, command, cwd)
|
||||
}
|
||||
|
||||
func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult {
|
||||
// timeout == 0 means no timeout
|
||||
var cmdCtx context.Context
|
||||
var cancel context.CancelFunc
|
||||
@@ -525,560 +361,6 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) runBackground(ctx context.Context, command, cwd string, ptyEnabled bool) *ToolResult {
|
||||
sessionID := generateSessionID()
|
||||
session := &ProcessSession{
|
||||
ID: sessionID,
|
||||
Command: command,
|
||||
PTY: ptyEnabled,
|
||||
Background: true,
|
||||
StartTime: time.Now().Unix(),
|
||||
Status: "running",
|
||||
ptyKeyMode: PtyKeyModeCSI,
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", command)
|
||||
} else {
|
||||
cmd = exec.Command("sh", "-c", command)
|
||||
}
|
||||
if cwd != "" {
|
||||
cmd.Dir = cwd
|
||||
}
|
||||
|
||||
prepareCommandForTermination(cmd)
|
||||
|
||||
var stdoutReader io.ReadCloser
|
||||
var stderrReader io.ReadCloser
|
||||
var stdinWriter io.WriteCloser
|
||||
|
||||
if ptyEnabled {
|
||||
ptmx, tty, err := pty.Open()
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create PTY: %v", err))
|
||||
}
|
||||
|
||||
cmd.Stdin = tty
|
||||
cmd.Stdout = tty
|
||||
cmd.Stderr = tty
|
||||
|
||||
// For PTY, we need Setsid to create a new session.
|
||||
// Note: Setsid and Setpgid conflict, so we must replace SysProcAttr entirely.
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
|
||||
|
||||
session.ptyMaster = ptmx
|
||||
} else {
|
||||
var err error
|
||||
stdoutReader, err = cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create stdout pipe: %v", err))
|
||||
}
|
||||
stderrReader, err = cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create stderr pipe: %v", err))
|
||||
}
|
||||
stdinWriter, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create stdin pipe: %v", err))
|
||||
}
|
||||
session.stdoutPipe = io.MultiReader(stdoutReader, stderrReader)
|
||||
session.stdinWriter = stdinWriter
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
if session.ptyMaster != nil {
|
||||
session.ptyMaster.Close()
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
|
||||
}
|
||||
|
||||
session.PID = cmd.Process.Pid
|
||||
t.sessionManager.Add(session)
|
||||
|
||||
session.outputBuffer = &bytes.Buffer{}
|
||||
|
||||
// PTY mode: read from ptyMaster and wait for process
|
||||
// Note: On Linux, closing ptyMaster doesn't interrupt blocking Read() calls,
|
||||
// so we need cmd.Wait() in a separate goroutine to detect process exit.
|
||||
if session.PTY && session.ptyMaster != nil {
|
||||
go func() {
|
||||
cmd.Wait() // Wait for process to exit
|
||||
session.mu.Lock()
|
||||
if cmd.ProcessState != nil {
|
||||
session.ExitCode = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
session.Status = "done"
|
||||
session.mu.Unlock()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := session.ptyMaster.Read(buf)
|
||||
if n > 0 {
|
||||
raw := string(buf[:n])
|
||||
if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound && mode != session.GetPtyKeyMode() {
|
||||
session.SetPtyKeyMode(mode)
|
||||
}
|
||||
|
||||
session.mu.Lock()
|
||||
if session.outputBuffer.Len() >= maxOutputBufferSize {
|
||||
if !session.outputTruncated {
|
||||
session.outputBuffer.WriteString(outputTruncateMarker)
|
||||
session.outputTruncated = true
|
||||
}
|
||||
} else {
|
||||
session.outputBuffer.Write(buf[:n])
|
||||
}
|
||||
session.mu.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
// Non-PTY mode: single goroutine reads pipes.
|
||||
// When Read() returns EOF (pipe closed), we break.
|
||||
// When process exits, OS closes pipe write end → Read() returns EOF → we exit.
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
// Read stdout
|
||||
for {
|
||||
n, err := stdoutReader.Read(buf)
|
||||
if n > 0 {
|
||||
session.mu.Lock()
|
||||
if session.outputBuffer.Len() >= maxOutputBufferSize {
|
||||
if !session.outputTruncated {
|
||||
session.outputBuffer.WriteString(outputTruncateMarker)
|
||||
session.outputTruncated = true
|
||||
}
|
||||
} else {
|
||||
session.outputBuffer.Write(buf[:n])
|
||||
}
|
||||
session.mu.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Read stderr
|
||||
for {
|
||||
n, err := stderrReader.Read(buf)
|
||||
if n > 0 {
|
||||
session.mu.Lock()
|
||||
if session.outputBuffer.Len() >= maxOutputBufferSize {
|
||||
if !session.outputTruncated {
|
||||
session.outputBuffer.WriteString(outputTruncateMarker)
|
||||
session.outputTruncated = true
|
||||
}
|
||||
} else {
|
||||
session.outputBuffer.Write(buf[:n])
|
||||
}
|
||||
session.mu.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// All pipes closed, get exit status
|
||||
if stdinWriter != nil {
|
||||
stdinWriter.Close()
|
||||
}
|
||||
cmd.Wait()
|
||||
|
||||
session.mu.Lock()
|
||||
if cmd.ProcessState != nil {
|
||||
session.ExitCode = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
session.Status = "done"
|
||||
session.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
resp := ExecResponse{
|
||||
SessionID: sessionID,
|
||||
Status: "running",
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
return &ToolResult{
|
||||
ForLLM: string(data),
|
||||
ForUser: fmt.Sprintf("Session %s started", sessionID),
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) executeList() *ToolResult {
|
||||
sessions := t.sessionManager.List()
|
||||
resp := ExecResponse{
|
||||
Sessions: sessions,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
return &ToolResult{
|
||||
ForLLM: string(data),
|
||||
ForUser: fmt.Sprintf("%d active sessions", len(sessions)),
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) executePoll(args map[string]any) *ToolResult {
|
||||
sessionID, ok := args["sessionId"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("sessionId is required")
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.Get(sessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionNotFound) {
|
||||
return ErrorResult(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
resp := ExecResponse{
|
||||
SessionID: sessionID,
|
||||
Status: session.GetStatus(),
|
||||
ExitCode: session.GetExitCode(),
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
return &ToolResult{
|
||||
ForLLM: string(data),
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) executeRead(args map[string]any) *ToolResult {
|
||||
sessionID, ok := args["sessionId"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("sessionId is required")
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.Get(sessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionNotFound) {
|
||||
return ErrorResult(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
output := session.Read()
|
||||
|
||||
resp := ExecResponse{
|
||||
SessionID: sessionID,
|
||||
Output: output,
|
||||
Status: session.GetStatus(),
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
return &ToolResult{
|
||||
ForLLM: string(data),
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) executeWrite(args map[string]any) *ToolResult {
|
||||
sessionID, ok := args["sessionId"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("sessionId is required")
|
||||
}
|
||||
|
||||
data, ok := args["data"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("data is required")
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.Get(sessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionNotFound) {
|
||||
return ErrorResult(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Write(data); err != nil {
|
||||
if errors.Is(err, ErrSessionDone) {
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to write to session: %v", err))
|
||||
}
|
||||
|
||||
resp := ExecResponse{
|
||||
SessionID: sessionID,
|
||||
Status: session.GetStatus(),
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
return &ToolResult{
|
||||
ForLLM: string(respData),
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) executeKill(args map[string]any) *ToolResult {
|
||||
sessionID, ok := args["sessionId"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("sessionId is required")
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.Get(sessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionNotFound) {
|
||||
return ErrorResult(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Kill(); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to kill session: %v", err))
|
||||
}
|
||||
|
||||
t.sessionManager.Remove(sessionID)
|
||||
|
||||
resp := ExecResponse{
|
||||
SessionID: sessionID,
|
||||
Status: "done",
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
return &ToolResult{
|
||||
ForLLM: string(data),
|
||||
ForUser: fmt.Sprintf("Session %s killed", sessionID),
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
// keyMap maps key names to their escape sequences.
|
||||
var keyMap = map[string]string{
|
||||
"enter": "\r",
|
||||
"return": "\r",
|
||||
"tab": "\t",
|
||||
"escape": "\x1b",
|
||||
"esc": "\x1b",
|
||||
"space": " ",
|
||||
"backspace": "\x7f",
|
||||
"bspace": "\x7f",
|
||||
"up": "\x1b[A",
|
||||
"down": "\x1b[B",
|
||||
"right": "\x1b[C",
|
||||
"left": "\x1b[D",
|
||||
"home": "\x1b[1~",
|
||||
"end": "\x1b[4~",
|
||||
"pageup": "\x1b[5~",
|
||||
"pagedown": "\x1b[6~",
|
||||
"pgup": "\x1b[5~",
|
||||
"pgdn": "\x1b[6~",
|
||||
"insert": "\x1b[2~",
|
||||
"ic": "\x1b[2~",
|
||||
"delete": "\x1b[3~",
|
||||
"del": "\x1b[3~",
|
||||
"dc": "\x1b[3~",
|
||||
"btab": "\x1b[Z",
|
||||
"f1": "\x1bOP",
|
||||
"f2": "\x1bOQ",
|
||||
"f3": "\x1bOR",
|
||||
"f4": "\x1bOS",
|
||||
"f5": "\x1b[15~",
|
||||
"f6": "\x1b[17~",
|
||||
"f7": "\x1b[18~",
|
||||
"f8": "\x1b[19~",
|
||||
"f9": "\x1b[20~",
|
||||
"f10": "\x1b[21~",
|
||||
"f11": "\x1b[23~",
|
||||
"f12": "\x1b[24~",
|
||||
}
|
||||
|
||||
// ss3KeysMap maps key names to SS3 escape sequences
|
||||
var ss3KeysMap = map[string]string{
|
||||
"up": "\x1bOA",
|
||||
"down": "\x1bOB",
|
||||
"right": "\x1bOC",
|
||||
"left": "\x1bOD",
|
||||
"home": "\x1bOH",
|
||||
"end": "\x1bOF",
|
||||
}
|
||||
|
||||
func detectPtyKeyMode(raw string) PtyKeyMode {
|
||||
const SMKX = "\x1b[?1h"
|
||||
const RMKX = "\x1b[?1l"
|
||||
|
||||
lastSmkx := strings.LastIndex(raw, SMKX)
|
||||
lastRmkx := strings.LastIndex(raw, RMKX)
|
||||
|
||||
if lastSmkx == -1 && lastRmkx == -1 {
|
||||
return PtyKeyModeNotFound
|
||||
}
|
||||
|
||||
if lastSmkx > lastRmkx {
|
||||
return PtyKeyModeSS3
|
||||
}
|
||||
return PtyKeyModeCSI
|
||||
}
|
||||
|
||||
// encodeKeyToken encodes a single key token into its escape sequence.
|
||||
// Supports:
|
||||
// - Named keys: "enter", "tab", "up", "ctrl-c", "alt-x", etc.
|
||||
// - Ctrl modifier: "ctrl-c" or "c-c" (sends Ctrl+char)
|
||||
// - Alt modifier: "alt-x" or "m-x" (sends ESC+char)
|
||||
func encodeKeyToken(token string, ptyKeyMode PtyKeyMode) (string, error) {
|
||||
token = strings.ToLower(strings.TrimSpace(token))
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Handle ctrl-X format (c-x)
|
||||
if strings.HasPrefix(token, "c-") {
|
||||
char := token[2]
|
||||
if char >= 'a' && char <= 'z' {
|
||||
return string(rune(char) & 0x1f), nil // ctrl-a through ctrl-z
|
||||
}
|
||||
return "", fmt.Errorf("invalid ctrl key: %s", token)
|
||||
}
|
||||
|
||||
// Handle ctrl-X format (ctrl-x)
|
||||
if strings.HasPrefix(token, "ctrl-") {
|
||||
char := token[5]
|
||||
if char >= 'a' && char <= 'z' {
|
||||
return string(rune(char) & 0x1f), nil
|
||||
}
|
||||
return "", fmt.Errorf("invalid ctrl key: %s", token)
|
||||
}
|
||||
|
||||
// Handle alt-X format (m-x or alt-x)
|
||||
if strings.HasPrefix(token, "m-") || strings.HasPrefix(token, "alt-") {
|
||||
var char string
|
||||
if strings.HasPrefix(token, "m-") {
|
||||
char = token[2:]
|
||||
} else {
|
||||
char = token[4:]
|
||||
}
|
||||
if len(char) == 1 {
|
||||
return "\x1b" + char, nil
|
||||
}
|
||||
return "", fmt.Errorf("invalid alt key: %s", token)
|
||||
}
|
||||
|
||||
// Handle shift modifier for special keys (shift-up, shift-down, etc.)
|
||||
if strings.HasPrefix(token, "s-") || strings.HasPrefix(token, "shift-") {
|
||||
var key string
|
||||
if strings.HasPrefix(token, "s-") {
|
||||
key = token[2:]
|
||||
} else {
|
||||
key = token[6:]
|
||||
}
|
||||
// Apply shift modifier: for single-char keys, return uppercase
|
||||
if seq, ok := keyMap[key]; ok {
|
||||
// For escape sequences, we can't easily add shift
|
||||
// For single-char keys (letters), return uppercase
|
||||
if len(seq) == 1 {
|
||||
return strings.ToUpper(seq), nil
|
||||
}
|
||||
return seq, nil
|
||||
}
|
||||
return "", fmt.Errorf("unknown key with shift: %s", key)
|
||||
}
|
||||
|
||||
if ptyKeyMode == PtyKeyModeSS3 {
|
||||
if seq, ok := ss3KeysMap[token]; ok {
|
||||
return seq, nil
|
||||
}
|
||||
}
|
||||
|
||||
if seq, ok := keyMap[token]; ok {
|
||||
return seq, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unknown key: %s (use write action for text input)", token)
|
||||
}
|
||||
|
||||
// encodeKeySequence encodes a slice of key tokens into a single string.
|
||||
func encodeKeySequence(tokens []string, ptyKeyMode PtyKeyMode) (string, error) {
|
||||
var result string
|
||||
for _, token := range tokens {
|
||||
seq, err := encodeKeyToken(token, ptyKeyMode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result += seq
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *ExecTool) executeSendKeys(args map[string]any) *ToolResult {
|
||||
sessionID, ok := args["sessionId"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("sessionId is required")
|
||||
}
|
||||
|
||||
keysStr, ok := args["keys"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("keys must be a string")
|
||||
}
|
||||
|
||||
if keysStr == "" {
|
||||
return ErrorResult("keys cannot be empty")
|
||||
}
|
||||
|
||||
// Parse comma-separated key names
|
||||
keyNames := strings.Split(keysStr, ",")
|
||||
var keys []string
|
||||
for _, k := range keyNames {
|
||||
k = strings.TrimSpace(k)
|
||||
if k != "" {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return ErrorResult("keys cannot be empty")
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.Get(sessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionNotFound) {
|
||||
return ErrorResult(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
ptyKeyMode := session.GetPtyKeyMode()
|
||||
|
||||
data, err := encodeKeySequence(keys, ptyKeyMode)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid key: %v", err))
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Write(data); err != nil {
|
||||
if errors.Is(err, ErrSessionDone) {
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to send keys: %v", err))
|
||||
}
|
||||
|
||||
resp := ExecResponse{
|
||||
SessionID: sessionID,
|
||||
Status: "running",
|
||||
Output: fmt.Sprintf("Sent keys: %v", keys),
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
return &ToolResult{
|
||||
ForLLM: string(respData),
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
cmd := strings.TrimSpace(command)
|
||||
lower := strings.ToLower(cmd)
|
||||
|
||||
+16
-946
File diff suppressed because it is too large
Load Diff
@@ -30,7 +30,6 @@ func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
|
||||
tool.SetTimeout(500 * time.Millisecond)
|
||||
|
||||
args := map[string]any{
|
||||
"action": "run",
|
||||
// Spawn a child process that would outlive the shell unless process-group kill is used.
|
||||
"command": "sleep 60 & echo $! > child.pid; wait",
|
||||
}
|
||||
|
||||
@@ -56,24 +56,3 @@ type ToolFunctionDefinition struct {
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]any `json:"parameters"`
|
||||
}
|
||||
|
||||
type ExecRequest struct {
|
||||
Action string `json:"action"`
|
||||
Command string `json:"command,omitempty"`
|
||||
PTY bool `json:"pty,omitempty"`
|
||||
Background bool `json:"background,omitempty"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
type ExecResponse struct {
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
ExitCode int `json:"exitCode,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Sessions []SessionInfo `json:"sessions,omitempty"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user