mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into feat/markdown-output-format-web-fetch
This commit is contained in:
+53
-20
@@ -20,10 +20,12 @@ type JobExecutor interface {
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
allowCommand bool
|
||||
execEnabled bool
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
@@ -32,17 +34,32 @@ func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) (*CronTool, error) {
|
||||
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
allowCommand := true
|
||||
execEnabled := true
|
||||
if config != nil {
|
||||
allowCommand = config.Tools.Cron.AllowCommand
|
||||
execEnabled = config.Tools.Exec.Enabled
|
||||
}
|
||||
|
||||
execTool.SetTimeout(execTimeout)
|
||||
var execTool *ExecTool
|
||||
if execEnabled {
|
||||
var err error
|
||||
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if execTool != nil {
|
||||
execTool.SetTimeout(execTimeout)
|
||||
}
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
allowCommand: allowCommand,
|
||||
execEnabled: execEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"command_confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
|
||||
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
|
||||
},
|
||||
"at_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"deliver": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
// Read deliver parameter, default to false so scheduled tasks execute through the agent
|
||||
deliver := false
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
|
||||
// Non-command reminders (plain messages) remain open to all channels.
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
|
||||
// allow_command is disabled, explicit confirmation is required as an override.
|
||||
// Non-command reminders remain open to all channels.
|
||||
command, _ := args["command"].(string)
|
||||
commandConfirm, _ := args["command_confirm"].(bool)
|
||||
if command != "" {
|
||||
if !t.execEnabled {
|
||||
return ErrorResult("command execution is disabled")
|
||||
}
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
return ErrorResult("scheduling command execution is restricted to internal channels")
|
||||
}
|
||||
if !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required to schedule command execution")
|
||||
if !t.allowCommand && !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required when allow_command is disabled")
|
||||
}
|
||||
deliver = false
|
||||
}
|
||||
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
|
||||
// Execute command if present
|
||||
if job.Payload.Command != "" {
|
||||
if !t.execEnabled || t.execTool == nil {
|
||||
output := "Error executing scheduled command: command execution is disabled"
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
args := map[string]any{
|
||||
"command": job.Payload.Command,
|
||||
"__channel": channel,
|
||||
|
||||
+126
-6
@@ -5,18 +5,18 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
)
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.DefaultConfig()
|
||||
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCronTool() error: %v", err)
|
||||
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
|
||||
return tool
|
||||
}
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
t.Helper()
|
||||
return newTestCronToolWithConfig(t, config.DefaultConfig())
|
||||
}
|
||||
|
||||
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
|
||||
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
|
||||
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when command_confirm is missing")
|
||||
t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command_confirm=true") {
|
||||
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
|
||||
t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected command scheduling to be blocked when exec is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command execution is disabled") {
|
||||
t.Errorf("expected exec disabled message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "send me a poem",
|
||||
"at_seconds": float64(600),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected non-command reminder to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
jobs := tool.cronService.ListJobs(false)
|
||||
if len(jobs) != 1 {
|
||||
t.Fatalf("expected 1 job, got %d", len(jobs))
|
||||
}
|
||||
if jobs[0].Payload.Deliver {
|
||||
t.Fatal("expected deliver=false by default for non-command jobs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
job := &cron.CronJob{}
|
||||
job.Payload.Channel = "cli"
|
||||
job.Payload.To = "direct"
|
||||
job.Payload.Command = "df -h"
|
||||
|
||||
if got := tool.ExecuteJob(context.Background(), job); got != "ok" {
|
||||
t.Fatalf("ExecuteJob() = %q, want ok", got)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
if !strings.Contains(msg.Content, "command execution is disabled") {
|
||||
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
+161
-9
@@ -20,8 +20,7 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
// validatePath ensures the given path is within the workspace if restrict is true.
|
||||
func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
|
||||
if restrict {
|
||||
if isAllowedPath(absPath, patterns) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
if len(patterns) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(path)
|
||||
if !filepath.IsAbs(cleaned) {
|
||||
return false
|
||||
}
|
||||
if !matchesAllowedPath(cleaned, patterns) {
|
||||
return false
|
||||
}
|
||||
|
||||
resolved, err := resolvePathAgainstExistingAncestor(cleaned)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchesAllowedPath(resolved, patterns)
|
||||
}
|
||||
|
||||
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
cleaned := filepath.Clean(path)
|
||||
for _, pattern := range patterns {
|
||||
if pattern.MatchString(cleaned) {
|
||||
return true
|
||||
}
|
||||
if root, ok := extractAllowedPathRoot(pattern); ok && isWithinAllowedRoot(cleaned, root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
|
||||
raw := pattern.String()
|
||||
if !strings.HasPrefix(raw, "^") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
literal := strings.TrimPrefix(raw, "^")
|
||||
|
||||
// Recognize the common "directory prefix" form: ^<literal>(?:/|$)
|
||||
literal = strings.TrimSuffix(literal, "(?:/|$)")
|
||||
literal = strings.TrimSuffix(literal, `(?:\\|$)`)
|
||||
|
||||
// Reject patterns that still contain regex operators after removing the
|
||||
// optional anchored-directory suffix. That keeps arbitrary regex behavior
|
||||
// unchanged and only enables normalized prefix matching for literal paths.
|
||||
if containsUnescapedRegexMeta(literal) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
unescaped, ok := unescapeRegexLiteral(literal)
|
||||
if !ok || unescaped == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return filepath.Clean(unescaped), filepath.IsAbs(unescaped)
|
||||
}
|
||||
|
||||
func appendUniquePath(paths []string, path string) []string {
|
||||
for _, existing := range paths {
|
||||
if existing == path {
|
||||
return paths
|
||||
}
|
||||
}
|
||||
return append(paths, path)
|
||||
}
|
||||
|
||||
func containsUnescapedRegexMeta(s string) bool {
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|':
|
||||
return true
|
||||
}
|
||||
}
|
||||
return escaped
|
||||
}
|
||||
|
||||
func unescapeRegexLiteral(s string) (string, bool) {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
b.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
if escaped {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func isWithinAllowedRoot(path, root string) bool {
|
||||
candidate := filepath.Clean(path)
|
||||
allowedVariants := []string{filepath.Clean(root)}
|
||||
|
||||
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
|
||||
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
|
||||
}
|
||||
|
||||
for _, allowedRoot := range allowedVariants {
|
||||
if isWithinWorkspace(candidate, allowedRoot) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolvePathAgainstExistingAncestor(path string) (string, error) {
|
||||
cleaned := filepath.Clean(path)
|
||||
for current := cleaned; ; current = filepath.Dir(current) {
|
||||
resolved, err := filepath.EvalSymlinks(current)
|
||||
if err == nil {
|
||||
suffix, relErr := filepath.Rel(current, cleaned)
|
||||
if relErr != nil {
|
||||
return "", relErr
|
||||
}
|
||||
if suffix == "." {
|
||||
return filepath.Clean(resolved), nil
|
||||
}
|
||||
return filepath.Clean(filepath.Join(resolved, suffix)), nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
return err == nil && (rel == "." || filepath.IsLocal(rel))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
@@ -625,12 +782,7 @@ type whitelistFs struct {
|
||||
}
|
||||
|
||||
func (w *whitelistFs) matches(path string) bool {
|
||||
for _, p := range w.patterns {
|
||||
if p.MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return isAllowedPath(path, w.patterns)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
|
||||
|
||||
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
allowedDir := t.TempDir()
|
||||
secretDir := t.TempDir()
|
||||
secretFile := filepath.Join(secretDir, "secret.txt")
|
||||
if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(secretFile) error = %v", err)
|
||||
}
|
||||
|
||||
linkPath := filepath.Join(allowedDir, "link_out")
|
||||
if err := os.Symlink(secretDir, linkPath); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
rootDir := t.TempDir()
|
||||
allowedDir := filepath.Join(rootDir, "allowed")
|
||||
targetFile := filepath.Join(allowedDir, "nested", "file.txt")
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewWriteFileTool(workspace, true, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": targetFile,
|
||||
"content": "outside write",
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(targetFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile(targetFile) error = %v", err)
|
||||
}
|
||||
if string(data) != "outside write" {
|
||||
t.Fatalf("target file content = %q, want %q", string(data), "outside write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
realDir := t.TempDir()
|
||||
linkParent := t.TempDir()
|
||||
allowedAlias := filepath.Join(linkParent, "allowed-link")
|
||||
|
||||
if err := os.Symlink(realDir, allowedAlias); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
targetFile := filepath.Join(allowedAlias, "nested", "alias.txt")
|
||||
if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(targetFile, []byte("through alias"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(targetFile) error = %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(allowedAlias)) +
|
||||
"(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
),
|
||||
}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": targetFile})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "through alias") {
|
||||
t.Fatalf("expected file content, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
|
||||
// by reading a file in multiple chunks using 'offset' and 'length'.
|
||||
func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
|
||||
+15
-2
@@ -6,6 +6,7 @@ import (
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
@@ -21,20 +22,32 @@ type SendFileTool struct {
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
}
|
||||
|
||||
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
|
||||
func NewSendFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
store media.MediaStore,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *SendFileTool {
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &SendFileTool{
|
||||
workspace: workspace,
|
||||
restrict: restrict,
|
||||
maxFileSize: maxFileSize,
|
||||
mediaStore: store,
|
||||
allowPaths: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("media store not configured")
|
||||
}
|
||||
|
||||
resolved, err := validatePath(path, t.workspace, t.restrict)
|
||||
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
testPath := testFile.Name()
|
||||
if _, err := testFile.WriteString("forward me"); err != nil {
|
||||
testFile.Close()
|
||||
t.Fatalf("WriteString(testFile) error = %v", err)
|
||||
}
|
||||
if err := testFile.Close(); err != nil {
|
||||
t.Fatalf("Close(testFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(testPath) })
|
||||
|
||||
pattern := regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern})
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": testPath})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_MagicBytes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
+31
-13
@@ -23,6 +23,7 @@ type ExecTool struct {
|
||||
denyPatterns []*regexp.Regexp
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
allowedPathPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
allowRemote bool
|
||||
}
|
||||
@@ -95,14 +96,23 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
func NewExecToolWithConfig(
|
||||
workingDir string,
|
||||
restrict bool,
|
||||
config *config.Config,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
var allowedPathPatterns []*regexp.Regexp
|
||||
allowRemote := true
|
||||
if len(allowPaths) > 0 {
|
||||
allowedPathPatterns = allowPaths[0]
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
allowedPathPatterns: allowedPathPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
allowRemote: allowRemote,
|
||||
}, nil
|
||||
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePath(wd, t.workingDir, true)
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
if isAllowedPath(resolved, t.allowedPathPatterns) {
|
||||
cwd = resolved
|
||||
} else {
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
|
||||
// timeout == 0 means no timeout
|
||||
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
if safePaths[p] {
|
||||
continue
|
||||
}
|
||||
if isAllowedPath(p, t.allowedPathPatterns) {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(cwdPath, p)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SpawnStatusTool reports the status of subagents that were spawned via the
|
||||
// spawn tool. It can query a specific task by ID, or list every known task with
|
||||
// a summary count broken-down by status.
|
||||
type SpawnStatusTool struct {
|
||||
manager *SubagentManager
|
||||
}
|
||||
|
||||
// NewSpawnStatusTool creates a SpawnStatusTool backed by the given manager.
|
||||
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
|
||||
return &SpawnStatusTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Name() string {
|
||||
return "spawn_status"
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Description() string {
|
||||
return "Get the status of spawned subagents. " +
|
||||
"Returns a list of all subagents and their current state " +
|
||||
"(running, completed, failed, or canceled), or retrieves details " +
|
||||
"for a specific subagent task when task_id is provided. " +
|
||||
"Results are scoped to the current conversation's channel and chat ID; " +
|
||||
"all tasks are listed only when no channel/chat context is injected " +
|
||||
"(e.g. direct programmatic calls via Execute)."
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"task_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
|
||||
"subagent. When omitted, all visible subagents are listed.",
|
||||
},
|
||||
},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Derive the calling conversation's identity so we can scope results to the
|
||||
// current chat only — preventing cross-conversation task leakage in
|
||||
// multi-user deployments.
|
||||
callerChannel := ToolChannel(ctx)
|
||||
callerChatID := ToolChatID(ctx)
|
||||
|
||||
var taskID string
|
||||
if rawTaskID, ok := args["task_id"]; ok && rawTaskID != nil {
|
||||
taskIDStr, ok := rawTaskID.(string)
|
||||
if !ok {
|
||||
return ErrorResult("task_id must be a string")
|
||||
}
|
||||
taskID = strings.TrimSpace(taskIDStr)
|
||||
}
|
||||
|
||||
if taskID != "" {
|
||||
// GetTaskCopy returns a consistent snapshot under the manager lock,
|
||||
// eliminating any data race with the concurrent subagent goroutine.
|
||||
taskCopy, ok := t.manager.GetTaskCopy(taskID)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
// Restrict lookup to tasks that belong to this conversation.
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
return NewToolResult(spawnStatusFormatTask(&taskCopy))
|
||||
}
|
||||
|
||||
// ListTaskCopies returns consistent snapshots under the manager lock.
|
||||
origTasks := t.manager.ListTaskCopies()
|
||||
if len(origTasks) == 0 {
|
||||
return NewToolResult("No subagents have been spawned yet.")
|
||||
}
|
||||
|
||||
tasks := make([]*SubagentTask, 0, len(origTasks))
|
||||
for i := range origTasks {
|
||||
cpy := &origTasks[i]
|
||||
|
||||
// Filter to tasks that originate from the current conversation only.
|
||||
if callerChannel != "" && cpy.OriginChannel != "" && cpy.OriginChannel != callerChannel {
|
||||
continue
|
||||
}
|
||||
if callerChatID != "" && cpy.OriginChatID != "" && cpy.OriginChatID != callerChatID {
|
||||
continue
|
||||
}
|
||||
|
||||
tasks = append(tasks, cpy)
|
||||
}
|
||||
|
||||
if len(tasks) == 0 {
|
||||
return NewToolResult("No subagents found for this conversation.")
|
||||
}
|
||||
|
||||
// Order by creation time (ascending) so spawning order is preserved.
|
||||
// Fall back to ID string for tasks created in the same millisecond.
|
||||
sort.Slice(tasks, func(i, j int) bool {
|
||||
if tasks[i].Created != tasks[j].Created {
|
||||
return tasks[i].Created < tasks[j].Created
|
||||
}
|
||||
return tasks[i].ID < tasks[j].ID
|
||||
})
|
||||
|
||||
counts := map[string]int{}
|
||||
for _, task := range tasks {
|
||||
counts[task.Status]++
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Subagent status report (%d total):\n", len(tasks)))
|
||||
for _, status := range []string{"running", "completed", "failed", "canceled"} {
|
||||
if n := counts[status]; n > 0 {
|
||||
label := strings.ToUpper(status[:1]) + status[1:] + ":"
|
||||
sb.WriteString(fmt.Sprintf(" %-10s %d\n", label, n))
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, task := range tasks {
|
||||
sb.WriteString(spawnStatusFormatTask(task))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
return NewToolResult(strings.TrimRight(sb.String(), "\n"))
|
||||
}
|
||||
|
||||
// spawnStatusFormatTask renders a single SubagentTask as a human-readable block.
|
||||
func spawnStatusFormatTask(task *SubagentTask) string {
|
||||
var sb strings.Builder
|
||||
|
||||
header := fmt.Sprintf("[%s] status=%s", task.ID, task.Status)
|
||||
if task.Label != "" {
|
||||
header += fmt.Sprintf(" label=%q", task.Label)
|
||||
}
|
||||
if task.AgentID != "" {
|
||||
header += fmt.Sprintf(" agent=%s", task.AgentID)
|
||||
}
|
||||
if task.Created > 0 {
|
||||
created := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
|
||||
header += fmt.Sprintf(" created=%s", created)
|
||||
}
|
||||
sb.WriteString(header)
|
||||
|
||||
if task.Task != "" {
|
||||
sb.WriteString(fmt.Sprintf("\n task: %s", task.Task))
|
||||
}
|
||||
if task.Result != "" {
|
||||
result := task.Result
|
||||
const maxResultLen = 300
|
||||
runes := []rune(result)
|
||||
if len(runes) > maxResultLen {
|
||||
result = string(runes[:maxResultLen]) + "…"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("\n result: %s", result))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSpawnStatusTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
if tool.Name() != "spawn_status" {
|
||||
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(desc), "subagent") {
|
||||
t.Errorf("Description should mention 'subagent', got: %s", desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got: %v", params["type"])
|
||||
}
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected 'properties' to be a map")
|
||||
}
|
||||
if _, hasTaskID := props["task_id"]; !hasTaskID {
|
||||
t.Error("Expected 'task_id' parameter in properties")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_NilManager(t *testing.T) {
|
||||
tool := &SpawnStatusTool{manager: nil}
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Error("Expected error result when manager is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Empty(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "No subagents") {
|
||||
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Do task A",
|
||||
Label: "task-a",
|
||||
Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2",
|
||||
Task: "Do task B",
|
||||
Label: "task-b",
|
||||
Status: "completed",
|
||||
Result: "Done successfully",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-3"] = &SubagentTask{
|
||||
ID: "subagent-3",
|
||||
Task: "Do task C",
|
||||
Status: "failed",
|
||||
Result: "Error: something went wrong",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Summary header
|
||||
if !strings.Contains(result.ForLLM, "3 total") {
|
||||
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Individual task IDs
|
||||
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
|
||||
if !strings.Contains(result.ForLLM, id) {
|
||||
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Status values
|
||||
for _, status := range []string{"running", "completed", "failed"} {
|
||||
if !strings.Contains(result.ForLLM, status) {
|
||||
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Result content
|
||||
if !strings.Contains(result.ForLLM, "Done successfully") {
|
||||
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-42"] = &SubagentTask{
|
||||
ID: "subagent-42",
|
||||
Task: "Specific task",
|
||||
Label: "my-task",
|
||||
Status: "failed",
|
||||
Result: "Something went wrong",
|
||||
Created: time.Now().UnixMilli(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-42") {
|
||||
t.Errorf("Expected task ID in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "failed") {
|
||||
t.Errorf("Expected status 'failed' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Something went wrong") {
|
||||
t.Errorf("Expected result text in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "my-task") {
|
||||
t.Errorf("Expected label in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "nonexistent-999") {
|
||||
t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task_id must be a string") {
|
||||
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
longResult := strings.Repeat("X", 500)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Long task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// Output should be shorter than the raw result due to truncation
|
||||
if len(result.ForLLM) >= len(longResult) {
|
||||
t.Errorf("Expected result to be truncated, but ForLLM is %d chars", len(result.ForLLM))
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
|
||||
cjkChar := string(rune(0x5b57))
|
||||
longResult := strings.Repeat(cjkChar, 400)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Unicode task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator in output")
|
||||
}
|
||||
// The truncated result must be valid UTF-8 (no split rune boundaries).
|
||||
if !strings.Contains(result.ForLLM, cjkChar) {
|
||||
t.Errorf("Expected CJK runes to appear intact in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
|
||||
id := fmt.Sprintf("subagent-%d", i+1)
|
||||
manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// The summary line should mention all statuses that have counts
|
||||
for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
|
||||
if !strings.Contains(result.ForLLM, want) {
|
||||
t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
// Intentionally insert with out-of-order IDs and timestamps that reflect
|
||||
// true spawn order: subagent-2 was spawned first, subagent-10 second.
|
||||
manager.tasks["subagent-10"] = &SubagentTask{
|
||||
ID: "subagent-10", Task: "second", Status: "running",
|
||||
Created: now + 1,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "first", Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
pos2 := strings.Index(result.ForLLM, "subagent-2")
|
||||
pos10 := strings.Index(result.ForLLM, "subagent-10")
|
||||
if pos2 < 0 || pos10 < 0 {
|
||||
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
|
||||
}
|
||||
if pos2 > pos10 {
|
||||
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "mine", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "other user", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-B",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Caller is chat-A — should only see subagent-1.
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-A")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "subagent-2") {
|
||||
t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-99"] = &SubagentTask{
|
||||
ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
|
||||
OriginChannel: "slack", OriginChatID: "room-Z",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Different chat trying to look up subagent-99 by ID.
|
||||
ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
|
||||
result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "t", Status: "completed",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// No ToolContext injected (e.g. a direct programmatic call that bypasses
|
||||
// WithToolContext entirely) — callerChannel and callerChatID are both "".
|
||||
// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
|
||||
// which *does* inject a non-empty context; this test covers the case where
|
||||
// no context injection happens at all.
|
||||
// The filter conditions require a non-empty caller value, so all tasks pass through.
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
+25
-3
@@ -109,9 +109,6 @@ func (sm *SubagentManager) Spawn(
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
|
||||
task.Status = "running"
|
||||
task.Created = time.Now().UnixMilli()
|
||||
|
||||
// Build system prompt for subagent
|
||||
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
|
||||
You have access to tools - use them as needed to complete your task.
|
||||
@@ -219,6 +216,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
|
||||
return task, ok
|
||||
}
|
||||
|
||||
// GetTaskCopy returns a copy of the task with the given ID, taken under the
|
||||
// read lock, so the caller receives a consistent snapshot with no data race.
|
||||
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if !ok {
|
||||
return SubagentTask{}, false
|
||||
}
|
||||
return *task, true
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
@@ -230,6 +239,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
return tasks
|
||||
}
|
||||
|
||||
// ListTaskCopies returns value copies of all tasks, taken under the read lock,
|
||||
// so callers receive consistent snapshots with no data race.
|
||||
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
copies := make([]SubagentTask, 0, len(sm.tasks))
|
||||
for _, task := range sm.tasks {
|
||||
copies = append(copies, *task)
|
||||
}
|
||||
return copies
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
|
||||
// and returns the result directly in the ToolResult.
|
||||
|
||||
+95
-10
@@ -780,11 +780,17 @@ type WebFetchTool struct {
|
||||
client *http.Client
|
||||
format string
|
||||
fetchLimitBytes int64
|
||||
whitelist *privateHostWhitelist
|
||||
}
|
||||
|
||||
type privateHostWhitelist struct {
|
||||
exact map[string]struct{}
|
||||
cidrs []*net.IPNet
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
// createHTTPClient cannot fail with an empty proxy string.
|
||||
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes)
|
||||
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
|
||||
@@ -792,9 +798,22 @@ func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFe
|
||||
var allowPrivateWebFetchHosts atomic.Bool
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithConfig(
|
||||
maxChars int,
|
||||
proxy string,
|
||||
fetchLimitBytes int64,
|
||||
privateHostWhitelist []string,
|
||||
) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
|
||||
}
|
||||
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
@@ -804,13 +823,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = newSafeDialContext(dialer)
|
||||
transport.DialContext = newSafeDialContext(dialer, whitelist)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
if isObviousPrivateHost(req.URL.Hostname()) {
|
||||
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
|
||||
return fmt.Errorf("redirect target is private or local network host")
|
||||
}
|
||||
return nil
|
||||
@@ -824,6 +843,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
|
||||
client: client,
|
||||
format: format,
|
||||
fetchLimitBytes: fetchLimitBytes,
|
||||
whitelist: whitelist,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -875,7 +895,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
|
||||
// The real SSRF guard is newSafeDialContext at connect time.
|
||||
hostname := parsedURL.Hostname()
|
||||
if isObviousPrivateHost(hostname) {
|
||||
if isObviousPrivateHost(hostname, t.whitelist) {
|
||||
return ErrorResult("fetching private or local network hosts is not allowed")
|
||||
}
|
||||
|
||||
@@ -1019,7 +1039,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
|
||||
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
|
||||
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
|
||||
func newSafeDialContext(
|
||||
dialer *net.Dialer,
|
||||
whitelist *privateHostWhitelist,
|
||||
) func(context.Context, string, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
@@ -1034,7 +1057,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateOrRestrictedIP(ip) {
|
||||
if shouldBlockPrivateIP(ip, whitelist) {
|
||||
return nil, fmt.Errorf("blocked private or local target: %s", host)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
@@ -1048,7 +1071,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
attempted := 0
|
||||
var lastErr error
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if isPrivateOrRestrictedIP(ipAddr.IP) {
|
||||
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
|
||||
continue
|
||||
}
|
||||
attempted++
|
||||
@@ -1060,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if attempted == 0 {
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
|
||||
@@ -1069,10 +1092,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
}
|
||||
|
||||
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
|
||||
if len(entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
whitelist := &privateHostWhitelist{
|
||||
exact: make(map[string]struct{}),
|
||||
cidrs: make([]*net.IPNet, 0, len(entries)),
|
||||
}
|
||||
for _, entry := range entries {
|
||||
entry = strings.TrimSpace(entry)
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
|
||||
continue
|
||||
}
|
||||
_, network, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
|
||||
}
|
||||
whitelist.cidrs = append(whitelist.cidrs, network)
|
||||
}
|
||||
|
||||
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return whitelist, nil
|
||||
}
|
||||
|
||||
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
|
||||
if w == nil || ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
normalized := normalizeWhitelistIP(ip)
|
||||
if _, ok := w.exact[normalized.String()]; ok {
|
||||
return true
|
||||
}
|
||||
for _, network := range w.cidrs {
|
||||
if network.Contains(normalized) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeWhitelistIP(ip net.IP) net.IP {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
|
||||
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
|
||||
}
|
||||
|
||||
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
|
||||
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
|
||||
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
|
||||
func isObviousPrivateHost(host string) bool {
|
||||
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return false
|
||||
}
|
||||
@@ -1088,7 +1173,7 @@ func isObviousPrivateHost(host string) bool {
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
return isPrivateOrRestrictedIP(ip)
|
||||
return shouldBlockPrivateIP(ip, whitelist)
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
@@ -425,6 +426,29 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
|
||||
t.Helper()
|
||||
hostPort := strings.TrimPrefix(rawURL, "http://")
|
||||
hostPort = strings.TrimPrefix(hostPort, "https://")
|
||||
host, port, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func singleHostCIDR(t *testing.T, host string) string {
|
||||
t.Helper()
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", host)
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
return ip.String() + "/32"
|
||||
}
|
||||
return ip.String() + "/128"
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
@@ -443,6 +467,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("exact whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{host})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("cidr whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{singleHostCIDR(t, host)})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
@@ -572,6 +646,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
|
||||
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err == nil {
|
||||
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
accepted := make(chan struct{}, 1)
|
||||
go func() {
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
accepted <- struct{}{}
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse whitelist: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
|
||||
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err != nil {
|
||||
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
select {
|
||||
case <-accepted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected localhost listener to accept a connection")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
|
||||
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -662,6 +799,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
||||
_, err := NewWebFetchToolWithConfig(1024, "", testFetchLimit, []string{"not-an-ip-or-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid whitelist entry to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid entry") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
|
||||
Reference in New Issue
Block a user