mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main into feat/subturn-poc
Includes JSONL session persistence (#1170), spawn_status tool, Azure provider, credential encryption, and various fixes. SubTurn features preserved and integrated with new spawn_status functionality.
This commit is contained in:
+53
-20
@@ -20,10 +20,12 @@ type JobExecutor interface {
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
allowCommand bool
|
||||
execEnabled bool
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
@@ -32,17 +34,32 @@ func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) (*CronTool, error) {
|
||||
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
allowCommand := true
|
||||
execEnabled := true
|
||||
if config != nil {
|
||||
allowCommand = config.Tools.Cron.AllowCommand
|
||||
execEnabled = config.Tools.Exec.Enabled
|
||||
}
|
||||
|
||||
execTool.SetTimeout(execTimeout)
|
||||
var execTool *ExecTool
|
||||
if execEnabled {
|
||||
var err error
|
||||
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if execTool != nil {
|
||||
execTool.SetTimeout(execTimeout)
|
||||
}
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
allowCommand: allowCommand,
|
||||
execEnabled: execEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"command_confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
|
||||
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
|
||||
},
|
||||
"at_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"deliver": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
// Read deliver parameter, default to false so scheduled tasks execute through the agent
|
||||
deliver := false
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
|
||||
// Non-command reminders (plain messages) remain open to all channels.
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
|
||||
// allow_command is disabled, explicit confirmation is required as an override.
|
||||
// Non-command reminders remain open to all channels.
|
||||
command, _ := args["command"].(string)
|
||||
commandConfirm, _ := args["command_confirm"].(bool)
|
||||
if command != "" {
|
||||
if !t.execEnabled {
|
||||
return ErrorResult("command execution is disabled")
|
||||
}
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
return ErrorResult("scheduling command execution is restricted to internal channels")
|
||||
}
|
||||
if !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required to schedule command execution")
|
||||
if !t.allowCommand && !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required when allow_command is disabled")
|
||||
}
|
||||
deliver = false
|
||||
}
|
||||
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
|
||||
// Execute command if present
|
||||
if job.Payload.Command != "" {
|
||||
if !t.execEnabled || t.execTool == nil {
|
||||
output := "Error executing scheduled command: command execution is disabled"
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
args := map[string]any{
|
||||
"command": job.Payload.Command,
|
||||
"__channel": channel,
|
||||
|
||||
+126
-6
@@ -5,18 +5,18 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
)
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.DefaultConfig()
|
||||
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCronTool() error: %v", err)
|
||||
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
|
||||
return tool
|
||||
}
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
t.Helper()
|
||||
return newTestCronToolWithConfig(t, config.DefaultConfig())
|
||||
}
|
||||
|
||||
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
|
||||
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
|
||||
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when command_confirm is missing")
|
||||
t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command_confirm=true") {
|
||||
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
|
||||
t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected command scheduling to be blocked when exec is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command execution is disabled") {
|
||||
t.Errorf("expected exec disabled message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "send me a poem",
|
||||
"at_seconds": float64(600),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected non-command reminder to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
jobs := tool.cronService.ListJobs(false)
|
||||
if len(jobs) != 1 {
|
||||
t.Fatalf("expected 1 job, got %d", len(jobs))
|
||||
}
|
||||
if jobs[0].Payload.Deliver {
|
||||
t.Fatal("expected deliver=false by default for non-command jobs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
job := &cron.CronJob{}
|
||||
job.Payload.Channel = "cli"
|
||||
job.Payload.To = "direct"
|
||||
job.Payload.Command = "df -h"
|
||||
|
||||
if got := tool.ExecuteJob(context.Background(), job); got != "ok" {
|
||||
t.Fatalf("ExecuteJob() = %q, want ok", got)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
if !strings.Contains(msg.Content, "command execution is disabled") {
|
||||
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
+161
-9
@@ -20,8 +20,7 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
// validatePath ensures the given path is within the workspace if restrict is true.
|
||||
func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
|
||||
if restrict {
|
||||
if isAllowedPath(absPath, patterns) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
if len(patterns) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(path)
|
||||
if !filepath.IsAbs(cleaned) {
|
||||
return false
|
||||
}
|
||||
if !matchesAllowedPath(cleaned, patterns) {
|
||||
return false
|
||||
}
|
||||
|
||||
resolved, err := resolvePathAgainstExistingAncestor(cleaned)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchesAllowedPath(resolved, patterns)
|
||||
}
|
||||
|
||||
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
cleaned := filepath.Clean(path)
|
||||
for _, pattern := range patterns {
|
||||
if pattern.MatchString(cleaned) {
|
||||
return true
|
||||
}
|
||||
if root, ok := extractAllowedPathRoot(pattern); ok && isWithinAllowedRoot(cleaned, root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
|
||||
raw := pattern.String()
|
||||
if !strings.HasPrefix(raw, "^") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
literal := strings.TrimPrefix(raw, "^")
|
||||
|
||||
// Recognize the common "directory prefix" form: ^<literal>(?:/|$)
|
||||
literal = strings.TrimSuffix(literal, "(?:/|$)")
|
||||
literal = strings.TrimSuffix(literal, `(?:\\|$)`)
|
||||
|
||||
// Reject patterns that still contain regex operators after removing the
|
||||
// optional anchored-directory suffix. That keeps arbitrary regex behavior
|
||||
// unchanged and only enables normalized prefix matching for literal paths.
|
||||
if containsUnescapedRegexMeta(literal) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
unescaped, ok := unescapeRegexLiteral(literal)
|
||||
if !ok || unescaped == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return filepath.Clean(unescaped), filepath.IsAbs(unescaped)
|
||||
}
|
||||
|
||||
func appendUniquePath(paths []string, path string) []string {
|
||||
for _, existing := range paths {
|
||||
if existing == path {
|
||||
return paths
|
||||
}
|
||||
}
|
||||
return append(paths, path)
|
||||
}
|
||||
|
||||
func containsUnescapedRegexMeta(s string) bool {
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|':
|
||||
return true
|
||||
}
|
||||
}
|
||||
return escaped
|
||||
}
|
||||
|
||||
func unescapeRegexLiteral(s string) (string, bool) {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
b.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
if escaped {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func isWithinAllowedRoot(path, root string) bool {
|
||||
candidate := filepath.Clean(path)
|
||||
allowedVariants := []string{filepath.Clean(root)}
|
||||
|
||||
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
|
||||
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
|
||||
}
|
||||
|
||||
for _, allowedRoot := range allowedVariants {
|
||||
if isWithinWorkspace(candidate, allowedRoot) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolvePathAgainstExistingAncestor(path string) (string, error) {
|
||||
cleaned := filepath.Clean(path)
|
||||
for current := cleaned; ; current = filepath.Dir(current) {
|
||||
resolved, err := filepath.EvalSymlinks(current)
|
||||
if err == nil {
|
||||
suffix, relErr := filepath.Rel(current, cleaned)
|
||||
if relErr != nil {
|
||||
return "", relErr
|
||||
}
|
||||
if suffix == "." {
|
||||
return filepath.Clean(resolved), nil
|
||||
}
|
||||
return filepath.Clean(filepath.Join(resolved, suffix)), nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
return err == nil && (rel == "." || filepath.IsLocal(rel))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
@@ -625,12 +782,7 @@ type whitelistFs struct {
|
||||
}
|
||||
|
||||
func (w *whitelistFs) matches(path string) bool {
|
||||
for _, p := range w.patterns {
|
||||
if p.MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return isAllowedPath(path, w.patterns)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
|
||||
|
||||
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
allowedDir := t.TempDir()
|
||||
secretDir := t.TempDir()
|
||||
secretFile := filepath.Join(secretDir, "secret.txt")
|
||||
if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(secretFile) error = %v", err)
|
||||
}
|
||||
|
||||
linkPath := filepath.Join(allowedDir, "link_out")
|
||||
if err := os.Symlink(secretDir, linkPath); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
rootDir := t.TempDir()
|
||||
allowedDir := filepath.Join(rootDir, "allowed")
|
||||
targetFile := filepath.Join(allowedDir, "nested", "file.txt")
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewWriteFileTool(workspace, true, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": targetFile,
|
||||
"content": "outside write",
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(targetFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile(targetFile) error = %v", err)
|
||||
}
|
||||
if string(data) != "outside write" {
|
||||
t.Fatalf("target file content = %q, want %q", string(data), "outside write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
realDir := t.TempDir()
|
||||
linkParent := t.TempDir()
|
||||
allowedAlias := filepath.Join(linkParent, "allowed-link")
|
||||
|
||||
if err := os.Symlink(realDir, allowedAlias); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
targetFile := filepath.Join(allowedAlias, "nested", "alias.txt")
|
||||
if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(targetFile, []byte("through alias"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(targetFile) error = %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(allowedAlias)) +
|
||||
"(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
),
|
||||
}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": targetFile})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "through alias") {
|
||||
t.Fatalf("expected file content, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
|
||||
// by reading a file in multiple chunks using 'offset' and 'length'.
|
||||
func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
|
||||
+15
-2
@@ -6,6 +6,7 @@ import (
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
@@ -21,20 +22,32 @@ type SendFileTool struct {
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
}
|
||||
|
||||
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
|
||||
func NewSendFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
store media.MediaStore,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *SendFileTool {
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &SendFileTool{
|
||||
workspace: workspace,
|
||||
restrict: restrict,
|
||||
maxFileSize: maxFileSize,
|
||||
mediaStore: store,
|
||||
allowPaths: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("media store not configured")
|
||||
}
|
||||
|
||||
resolved, err := validatePath(path, t.workspace, t.restrict)
|
||||
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
testPath := testFile.Name()
|
||||
if _, err := testFile.WriteString("forward me"); err != nil {
|
||||
testFile.Close()
|
||||
t.Fatalf("WriteString(testFile) error = %v", err)
|
||||
}
|
||||
if err := testFile.Close(); err != nil {
|
||||
t.Fatalf("Close(testFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(testPath) })
|
||||
|
||||
pattern := regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern})
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": testPath})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_MagicBytes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
+31
-13
@@ -23,6 +23,7 @@ type ExecTool struct {
|
||||
denyPatterns []*regexp.Regexp
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
allowedPathPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
allowRemote bool
|
||||
}
|
||||
@@ -95,14 +96,23 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
func NewExecToolWithConfig(
|
||||
workingDir string,
|
||||
restrict bool,
|
||||
config *config.Config,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
var allowedPathPatterns []*regexp.Regexp
|
||||
allowRemote := true
|
||||
if len(allowPaths) > 0 {
|
||||
allowedPathPatterns = allowPaths[0]
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
allowedPathPatterns: allowedPathPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
allowRemote: allowRemote,
|
||||
}, nil
|
||||
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePath(wd, t.workingDir, true)
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
if isAllowedPath(resolved, t.allowedPathPatterns) {
|
||||
cwd = resolved
|
||||
} else {
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
|
||||
// timeout == 0 means no timeout
|
||||
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
if safePaths[p] {
|
||||
continue
|
||||
}
|
||||
if isAllowedPath(p, t.allowedPathPatterns) {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(cwdPath, p)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SpawnStatusTool reports the status of subagents that were spawned via the
|
||||
// spawn tool. It can query a specific task by ID, or list every known task with
|
||||
// a summary count broken-down by status.
|
||||
type SpawnStatusTool struct {
|
||||
manager *SubagentManager
|
||||
}
|
||||
|
||||
// NewSpawnStatusTool creates a SpawnStatusTool backed by the given manager.
|
||||
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
|
||||
return &SpawnStatusTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Name() string {
|
||||
return "spawn_status"
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Description() string {
|
||||
return "Get the status of spawned subagents. " +
|
||||
"Returns a list of all subagents and their current state " +
|
||||
"(running, completed, failed, or canceled), or retrieves details " +
|
||||
"for a specific subagent task when task_id is provided. " +
|
||||
"Results are scoped to the current conversation's channel and chat ID; " +
|
||||
"all tasks are listed only when no channel/chat context is injected " +
|
||||
"(e.g. direct programmatic calls via Execute)."
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"task_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
|
||||
"subagent. When omitted, all visible subagents are listed.",
|
||||
},
|
||||
},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Derive the calling conversation's identity so we can scope results to the
|
||||
// current chat only — preventing cross-conversation task leakage in
|
||||
// multi-user deployments.
|
||||
callerChannel := ToolChannel(ctx)
|
||||
callerChatID := ToolChatID(ctx)
|
||||
|
||||
var taskID string
|
||||
if rawTaskID, ok := args["task_id"]; ok && rawTaskID != nil {
|
||||
taskIDStr, ok := rawTaskID.(string)
|
||||
if !ok {
|
||||
return ErrorResult("task_id must be a string")
|
||||
}
|
||||
taskID = strings.TrimSpace(taskIDStr)
|
||||
}
|
||||
|
||||
if taskID != "" {
|
||||
// GetTaskCopy returns a consistent snapshot under the manager lock,
|
||||
// eliminating any data race with the concurrent subagent goroutine.
|
||||
taskCopy, ok := t.manager.GetTaskCopy(taskID)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
// Restrict lookup to tasks that belong to this conversation.
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
return NewToolResult(spawnStatusFormatTask(&taskCopy))
|
||||
}
|
||||
|
||||
// ListTaskCopies returns consistent snapshots under the manager lock.
|
||||
origTasks := t.manager.ListTaskCopies()
|
||||
if len(origTasks) == 0 {
|
||||
return NewToolResult("No subagents have been spawned yet.")
|
||||
}
|
||||
|
||||
tasks := make([]*SubagentTask, 0, len(origTasks))
|
||||
for i := range origTasks {
|
||||
cpy := &origTasks[i]
|
||||
|
||||
// Filter to tasks that originate from the current conversation only.
|
||||
if callerChannel != "" && cpy.OriginChannel != "" && cpy.OriginChannel != callerChannel {
|
||||
continue
|
||||
}
|
||||
if callerChatID != "" && cpy.OriginChatID != "" && cpy.OriginChatID != callerChatID {
|
||||
continue
|
||||
}
|
||||
|
||||
tasks = append(tasks, cpy)
|
||||
}
|
||||
|
||||
if len(tasks) == 0 {
|
||||
return NewToolResult("No subagents found for this conversation.")
|
||||
}
|
||||
|
||||
// Order by creation time (ascending) so spawning order is preserved.
|
||||
// Fall back to ID string for tasks created in the same millisecond.
|
||||
sort.Slice(tasks, func(i, j int) bool {
|
||||
if tasks[i].Created != tasks[j].Created {
|
||||
return tasks[i].Created < tasks[j].Created
|
||||
}
|
||||
return tasks[i].ID < tasks[j].ID
|
||||
})
|
||||
|
||||
counts := map[string]int{}
|
||||
for _, task := range tasks {
|
||||
counts[task.Status]++
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Subagent status report (%d total):\n", len(tasks)))
|
||||
for _, status := range []string{"running", "completed", "failed", "canceled"} {
|
||||
if n := counts[status]; n > 0 {
|
||||
label := strings.ToUpper(status[:1]) + status[1:] + ":"
|
||||
sb.WriteString(fmt.Sprintf(" %-10s %d\n", label, n))
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, task := range tasks {
|
||||
sb.WriteString(spawnStatusFormatTask(task))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
return NewToolResult(strings.TrimRight(sb.String(), "\n"))
|
||||
}
|
||||
|
||||
// spawnStatusFormatTask renders a single SubagentTask as a human-readable block.
|
||||
func spawnStatusFormatTask(task *SubagentTask) string {
|
||||
var sb strings.Builder
|
||||
|
||||
header := fmt.Sprintf("[%s] status=%s", task.ID, task.Status)
|
||||
if task.Label != "" {
|
||||
header += fmt.Sprintf(" label=%q", task.Label)
|
||||
}
|
||||
if task.AgentID != "" {
|
||||
header += fmt.Sprintf(" agent=%s", task.AgentID)
|
||||
}
|
||||
if task.Created > 0 {
|
||||
created := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
|
||||
header += fmt.Sprintf(" created=%s", created)
|
||||
}
|
||||
sb.WriteString(header)
|
||||
|
||||
if task.Task != "" {
|
||||
sb.WriteString(fmt.Sprintf("\n task: %s", task.Task))
|
||||
}
|
||||
if task.Result != "" {
|
||||
result := task.Result
|
||||
const maxResultLen = 300
|
||||
runes := []rune(result)
|
||||
if len(runes) > maxResultLen {
|
||||
result = string(runes[:maxResultLen]) + "…"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("\n result: %s", result))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSpawnStatusTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
if tool.Name() != "spawn_status" {
|
||||
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(desc), "subagent") {
|
||||
t.Errorf("Description should mention 'subagent', got: %s", desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got: %v", params["type"])
|
||||
}
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected 'properties' to be a map")
|
||||
}
|
||||
if _, hasTaskID := props["task_id"]; !hasTaskID {
|
||||
t.Error("Expected 'task_id' parameter in properties")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_NilManager(t *testing.T) {
|
||||
tool := &SpawnStatusTool{manager: nil}
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Error("Expected error result when manager is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Empty(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "No subagents") {
|
||||
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Do task A",
|
||||
Label: "task-a",
|
||||
Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2",
|
||||
Task: "Do task B",
|
||||
Label: "task-b",
|
||||
Status: "completed",
|
||||
Result: "Done successfully",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-3"] = &SubagentTask{
|
||||
ID: "subagent-3",
|
||||
Task: "Do task C",
|
||||
Status: "failed",
|
||||
Result: "Error: something went wrong",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Summary header
|
||||
if !strings.Contains(result.ForLLM, "3 total") {
|
||||
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Individual task IDs
|
||||
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
|
||||
if !strings.Contains(result.ForLLM, id) {
|
||||
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Status values
|
||||
for _, status := range []string{"running", "completed", "failed"} {
|
||||
if !strings.Contains(result.ForLLM, status) {
|
||||
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Result content
|
||||
if !strings.Contains(result.ForLLM, "Done successfully") {
|
||||
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-42"] = &SubagentTask{
|
||||
ID: "subagent-42",
|
||||
Task: "Specific task",
|
||||
Label: "my-task",
|
||||
Status: "failed",
|
||||
Result: "Something went wrong",
|
||||
Created: time.Now().UnixMilli(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-42") {
|
||||
t.Errorf("Expected task ID in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "failed") {
|
||||
t.Errorf("Expected status 'failed' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Something went wrong") {
|
||||
t.Errorf("Expected result text in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "my-task") {
|
||||
t.Errorf("Expected label in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "nonexistent-999") {
|
||||
t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task_id must be a string") {
|
||||
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
longResult := strings.Repeat("X", 500)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Long task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// Output should be shorter than the raw result due to truncation
|
||||
if len(result.ForLLM) >= len(longResult) {
|
||||
t.Errorf("Expected result to be truncated, but ForLLM is %d chars", len(result.ForLLM))
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
|
||||
cjkChar := string(rune(0x5b57))
|
||||
longResult := strings.Repeat(cjkChar, 400)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Unicode task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator in output")
|
||||
}
|
||||
// The truncated result must be valid UTF-8 (no split rune boundaries).
|
||||
if !strings.Contains(result.ForLLM, cjkChar) {
|
||||
t.Errorf("Expected CJK runes to appear intact in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
|
||||
id := fmt.Sprintf("subagent-%d", i+1)
|
||||
manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// The summary line should mention all statuses that have counts
|
||||
for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
|
||||
if !strings.Contains(result.ForLLM, want) {
|
||||
t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
// Intentionally insert with out-of-order IDs and timestamps that reflect
|
||||
// true spawn order: subagent-2 was spawned first, subagent-10 second.
|
||||
manager.tasks["subagent-10"] = &SubagentTask{
|
||||
ID: "subagent-10", Task: "second", Status: "running",
|
||||
Created: now + 1,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "first", Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
pos2 := strings.Index(result.ForLLM, "subagent-2")
|
||||
pos10 := strings.Index(result.ForLLM, "subagent-10")
|
||||
if pos2 < 0 || pos10 < 0 {
|
||||
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
|
||||
}
|
||||
if pos2 > pos10 {
|
||||
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "mine", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "other user", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-B",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Caller is chat-A — should only see subagent-1.
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-A")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "subagent-2") {
|
||||
t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-99"] = &SubagentTask{
|
||||
ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
|
||||
OriginChannel: "slack", OriginChatID: "room-Z",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Different chat trying to look up subagent-99 by ID.
|
||||
ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
|
||||
result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "t", Status: "completed",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// No ToolContext injected (e.g. a direct programmatic call that bypasses
|
||||
// WithToolContext entirely) — callerChannel and callerChatID are both "".
|
||||
// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
|
||||
// which *does* inject a non-empty context; this test covers the case where
|
||||
// no context injection happens at all.
|
||||
// The filter conditions require a non-empty caller value, so all tasks pass through.
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -255,6 +255,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
|
||||
return task, ok
|
||||
}
|
||||
|
||||
// GetTaskCopy returns a copy of the task with the given ID, taken under the
|
||||
// read lock, so the caller receives a consistent snapshot with no data race.
|
||||
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if !ok {
|
||||
return SubagentTask{}, false
|
||||
}
|
||||
return *task, true
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
@@ -266,6 +278,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
return tasks
|
||||
}
|
||||
|
||||
// ListTaskCopies returns value copies of all tasks, taken under the read lock,
|
||||
// so callers receive consistent snapshots with no data race.
|
||||
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
copies := make([]SubagentTask, 0, len(sm.tasks))
|
||||
for _, task := range sm.tasks {
|
||||
copies = append(copies, *task)
|
||||
}
|
||||
return copies
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// It directly calls SubTurnSpawner with Async=false for synchronous execution.
|
||||
type SubagentTool struct {
|
||||
|
||||
Reference in New Issue
Block a user