mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(security): harden unauthenticated tool-exec paths (#1360)
* fix(security): harden unauthenticated tool-exec paths (GHSA-pv8c-p6jf-3fpp) - Exec tool: channel-based access control (default deny remote) - Cron tool: command scheduling restricted to internal channels - Web fetch: SSRF defense-in-depth (pre-flight + dial-time + redirect checks) - File permissions: session/state dirs 0700, files 0600 - Registry: inject __channel/__chat_id into tool args (replaces racy SetContext) 28 new security regression tests. (cherry picked from commit 191446ae19021604d3d5b0d9376b9655ab749105) * fix(exec): revalidate working_dir before command start * test(web): allow local oversized payload fixture --------- Co-authored-by: xj <gh-xj@users.noreply.github.com>
This commit is contained in:
@@ -209,7 +209,7 @@ func TestWeComAppVerifySignature(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty token skips verification", func(t *testing.T) {
|
||||
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
|
||||
cfgEmpty := config.WeComAppConfig{
|
||||
CorpID: "test_corp_id",
|
||||
CorpSecret: "test_secret",
|
||||
@@ -218,8 +218,8 @@ func TestWeComAppVerifySignature(t *testing.T) {
|
||||
}
|
||||
chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus)
|
||||
|
||||
if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
|
||||
t.Error("empty token should skip verification and return true")
|
||||
if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
|
||||
t.Error("empty token should reject verification (fail-closed)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -189,8 +189,7 @@ func TestWeComBotVerifySignature(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty token skips verification", func(t *testing.T) {
|
||||
// Create a channel manually with empty token to test the behavior
|
||||
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
|
||||
cfgEmpty := config.WeComConfig{
|
||||
Token: "",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
@@ -199,8 +198,8 @@ func TestWeComBotVerifySignature(t *testing.T) {
|
||||
config: cfgEmpty,
|
||||
}
|
||||
|
||||
if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
|
||||
t.Error("empty token should skip verification and return true")
|
||||
if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
|
||||
t.Error("empty token should reject verification (fail-closed)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func computeSignature(token, timestamp, nonce, encrypt string) string {
|
||||
// This is a common function used by both WeCom Bot and WeCom App
|
||||
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
|
||||
if token == "" {
|
||||
return true // Skip verification if token is not set
|
||||
return false
|
||||
}
|
||||
return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
|
||||
}
|
||||
|
||||
@@ -673,6 +673,7 @@ type CronToolsConfig struct {
|
||||
type ExecConfig struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"`
|
||||
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
|
||||
AllowRemote bool ` env:"PICOCLAW_TOOLS_EXEC_ALLOW_REMOTE" json:"allow_remote"`
|
||||
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
|
||||
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
|
||||
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
|
||||
|
||||
@@ -427,6 +427,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
},
|
||||
EnableDenyPatterns: true,
|
||||
AllowRemote: false,
|
||||
TimeoutSeconds: 60,
|
||||
},
|
||||
Skills: SkillsToolsConfig{
|
||||
|
||||
@@ -32,7 +32,7 @@ func NewSessionManager(storage string) *SessionManager {
|
||||
}
|
||||
|
||||
if storage != "" {
|
||||
os.MkdirAll(storage, 0o755)
|
||||
os.MkdirAll(storage, 0o700)
|
||||
sm.loadSessions()
|
||||
}
|
||||
|
||||
@@ -216,7 +216,7 @@ func (sm *SessionManager) Save(key string) error {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Chmod(0o644); err != nil {
|
||||
if err := tmpFile.Chmod(0o600); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
+2
-2
@@ -40,8 +40,8 @@ func NewManager(workspace string) *Manager {
|
||||
oldStateFile := filepath.Join(workspace, "state.json")
|
||||
|
||||
// Create state directory if it doesn't exist
|
||||
if err := os.MkdirAll(stateDir, 0o755); err != nil {
|
||||
log.Fatalf("[FATAL] state: failed to create state directory: %v", err)
|
||||
if err := os.MkdirAll(stateDir, 0o700); err != nil {
|
||||
log.Printf("[WARN] state: failed to create state directory %s: %v", stateDir, err)
|
||||
}
|
||||
|
||||
sm := &Manager{
|
||||
|
||||
+4
-12
@@ -2,7 +2,6 @@ package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -217,10 +216,7 @@ func TestNewManager_EmptyWorkspace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_MkdirFailureCrashes(t *testing.T) {
|
||||
// Since log.Fatalf calls os.Exit(1), we cannot test it normally
|
||||
// Otherwise, the test suite would stop altogether.
|
||||
// We use the standard pattern of Go: rerun this test in a subprocess.
|
||||
func TestNewManager_MkdirFailureDoesNotCrash(t *testing.T) {
|
||||
if os.Getenv("BE_CRASHER") == "1" {
|
||||
tmpDir := os.Getenv("CRASH_DIR")
|
||||
|
||||
@@ -240,15 +236,11 @@ func TestNewManager_MkdirFailureCrashes(t *testing.T) {
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureCrashes")
|
||||
cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureDoesNotCrash")
|
||||
cmd.Env = append(os.Environ(), "BE_CRASHER=1", "CRASH_DIR="+tmpDir)
|
||||
|
||||
err = cmd.Run()
|
||||
|
||||
var e *exec.ExitError
|
||||
if errors.As(err, &e) && !e.Success() {
|
||||
return
|
||||
if err != nil {
|
||||
t.Fatalf("NewManager should not crash when state dir creation fails, got: %v", err)
|
||||
}
|
||||
|
||||
t.Fatalf("The process ended without error, a crash was expected via os.Exit(1). Err: %v", err)
|
||||
}
|
||||
|
||||
+17
-5
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
@@ -73,6 +74,10 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
"type": "string",
|
||||
"description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.",
|
||||
},
|
||||
"command_confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
|
||||
},
|
||||
"at_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
|
||||
@@ -175,12 +180,17 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
|
||||
// Non-command reminders (plain messages) remain open to all channels.
|
||||
command, _ := args["command"].(string)
|
||||
commandConfirm, _ := args["command_confirm"].(bool)
|
||||
if command != "" {
|
||||
// Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically)
|
||||
// Actually, let's keep deliver=false to let the system know it's not a simple chat message
|
||||
// But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set.
|
||||
// However, logically, it's not "delivered" to chat directly as is.
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
return ErrorResult("scheduling command execution is restricted to internal channels")
|
||||
}
|
||||
if !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required to schedule command execution")
|
||||
}
|
||||
deliver = false
|
||||
}
|
||||
|
||||
@@ -281,7 +291,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
// Execute command if present
|
||||
if job.Payload.Command != "" {
|
||||
args := map[string]any{
|
||||
"command": job.Payload.Command,
|
||||
"command": job.Payload.Command,
|
||||
"__channel": channel,
|
||||
"__chat_id": chatID,
|
||||
}
|
||||
|
||||
result := t.execTool.Execute(ctx, args)
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
)
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.DefaultConfig()
|
||||
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCronTool() error: %v", err)
|
||||
}
|
||||
return tool
|
||||
}
|
||||
|
||||
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
|
||||
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected command scheduling to be blocked from remote channel")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "restricted to internal channels") {
|
||||
t.Errorf("expected 'restricted to internal channels', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
|
||||
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when command_confirm is missing")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command_confirm=true") {
|
||||
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_CommandAllowedFromInternalChannel verifies command scheduling works from internal channels
|
||||
func TestCronTool_CommandAllowedFromInternalChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected command scheduling to succeed from internal channel, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_AddJobRequiresSessionContext verifies fail-closed when channel/chatID missing
|
||||
func TestCronTool_AddJobRequiresSessionContext(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"action": "add",
|
||||
"message": "reminder",
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when session context is missing")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "no session context") {
|
||||
t.Errorf("expected 'no session context' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_NonCommandJobAllowedFromRemoteChannel verifies regular reminders work from any channel
|
||||
func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "time to stretch",
|
||||
"at_seconds": float64(600),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
)
|
||||
|
||||
type ExecTool struct {
|
||||
@@ -23,6 +24,7 @@ type ExecTool struct {
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
allowRemote bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -100,10 +102,12 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
allowRemote := true
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
enableDenyPatterns := execConfig.EnableDenyPatterns
|
||||
allowRemote = execConfig.AllowRemote
|
||||
if enableDenyPatterns {
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
if len(execConfig.CustomDenyPatterns) > 0 {
|
||||
@@ -143,6 +147,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
allowRemote: allowRemote,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -177,6 +182,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
return ErrorResult("command is required")
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: block exec from remote channels (e.g. Telegram webhooks)
|
||||
// unless explicitly opted-in via config. Fail-closed: empty channel = blocked.
|
||||
if !t.allowRemote {
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel, _ = args["__channel"].(string)
|
||||
}
|
||||
channel = strings.TrimSpace(channel)
|
||||
if channel == "" || !constants.IsInternalChannel(channel) {
|
||||
return ErrorResult("exec is restricted to internal channels")
|
||||
}
|
||||
}
|
||||
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
@@ -201,6 +219,25 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
return ErrorResult(guardError)
|
||||
}
|
||||
|
||||
// Re-resolve symlinks immediately before execution to shrink the TOCTOU window
|
||||
// between validation and cmd.Dir assignment.
|
||||
if t.restrictToWorkspace && t.workingDir != "" && cwd != t.workingDir {
|
||||
resolved, err := filepath.EvalSymlinks(cwd)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
|
||||
// timeout == 0 means no timeout
|
||||
var cmdCtx context.Context
|
||||
var cancel context.CancelFunc
|
||||
|
||||
@@ -301,6 +301,85 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_RemoteChannelBlockedByDefault verifies exec is blocked for remote channels
|
||||
func TestShellTool_RemoteChannelBlockedByDefault(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Tools.Exec.EnableDenyPatterns = true
|
||||
cfg.Tools.Exec.AllowRemote = false
|
||||
|
||||
tool, err := NewExecToolWithConfig("", false, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewExecToolWithConfig() error: %v", err)
|
||||
}
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{"command": "echo hi"})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected remote-channel exec to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "restricted to internal channels") {
|
||||
t.Errorf("expected 'restricted to internal channels' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_InternalChannelAllowed verifies exec is allowed for internal channels
|
||||
func TestShellTool_InternalChannelAllowed(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Tools.Exec.EnableDenyPatterns = true
|
||||
cfg.Tools.Exec.AllowRemote = false
|
||||
|
||||
tool, err := NewExecToolWithConfig("", false, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewExecToolWithConfig() error: %v", err)
|
||||
}
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{"command": "echo hi"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected internal channel exec to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "hi") {
|
||||
t.Errorf("expected output to contain 'hi', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_EmptyChannelBlockedWhenNotAllowRemote verifies fail-closed when no channel context
|
||||
func TestShellTool_EmptyChannelBlockedWhenNotAllowRemote(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Tools.Exec.EnableDenyPatterns = true
|
||||
cfg.Tools.Exec.AllowRemote = false
|
||||
|
||||
tool, err := NewExecToolWithConfig("", false, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewExecToolWithConfig() error: %v", err)
|
||||
}
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "echo hi",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected exec with empty channel to be blocked when allowRemote=false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_AllowRemoteBypassesChannelCheck verifies allowRemote=true permits any channel
|
||||
func TestShellTool_AllowRemoteBypassesChannelCheck(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Tools.Exec.EnableDenyPatterns = true
|
||||
cfg.Tools.Exec.AllowRemote = true
|
||||
|
||||
tool, err := NewExecToolWithConfig("", false, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewExecToolWithConfig() error: %v", err)
|
||||
}
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{"command": "echo hi"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected allowRemote=true to permit remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_RestrictToWorkspace verifies workspace restriction
|
||||
func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
+146
-1
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -818,6 +819,10 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error)
|
||||
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
|
||||
}
|
||||
|
||||
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
|
||||
// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily.
|
||||
var allowPrivateWebFetchHosts atomic.Bool
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = defaultMaxChars
|
||||
@@ -826,10 +831,20 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
}
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = newSafeDialContext(dialer)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
if isObviousPrivateHost(req.URL.Hostname()) {
|
||||
return fmt.Errorf("redirect target is private or local network host")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fetchLimitBytes <= 0 {
|
||||
@@ -888,6 +903,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("missing domain in URL")
|
||||
}
|
||||
|
||||
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
|
||||
// The real SSRF guard is newSafeDialContext at connect time.
|
||||
hostname := parsedURL.Hostname()
|
||||
if isObviousPrivateHost(hostname) {
|
||||
return ErrorResult("fetching private or local network hosts is not allowed")
|
||||
}
|
||||
|
||||
maxChars := t.maxChars
|
||||
if mc, ok := args["maxChars"].(float64); ok {
|
||||
if int(mc) > 100 {
|
||||
@@ -901,7 +923,6 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
@@ -992,3 +1013,127 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
return strings.Join(cleanLines, "\n")
|
||||
}
|
||||
|
||||
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
|
||||
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
|
||||
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid target address %q: %w", address, err)
|
||||
}
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("empty target host")
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateOrRestrictedIP(ip) {
|
||||
return nil, fmt.Errorf("blocked private or local target: %s", host)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
}
|
||||
|
||||
ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve %s: %w", host, err)
|
||||
}
|
||||
|
||||
attempted := 0
|
||||
var lastErr error
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if isPrivateOrRestrictedIP(ipAddr.IP) {
|
||||
continue
|
||||
}
|
||||
attempted++
|
||||
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if attempted == 0 {
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
|
||||
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
|
||||
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
|
||||
func isObviousPrivateHost(host string) bool {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
h := strings.ToLower(strings.TrimSpace(host))
|
||||
h = strings.TrimSuffix(h, ".")
|
||||
if h == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
if h == "localhost" || strings.HasSuffix(h, ".localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
return isPrivateOrRestrictedIP(ip)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isPrivateOrRestrictedIP returns true for IPs that should never be reached via web_fetch:
|
||||
// RFC 1918, loopback, link-local (incl. cloud metadata 169.254.x.x), carrier-grade NAT,
|
||||
// IPv6 unique-local (fc00::/7), 6to4 (2002::/16), and Teredo (2001:0000::/32).
|
||||
func isPrivateOrRestrictedIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
|
||||
ip.IsMulticast() || ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4 private, loopback, link-local, and carrier-grade NAT ranges.
|
||||
if ip4[0] == 10 ||
|
||||
ip4[0] == 127 ||
|
||||
ip4[0] == 0 ||
|
||||
(ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) ||
|
||||
(ip4[0] == 192 && ip4[1] == 168) ||
|
||||
(ip4[0] == 169 && ip4[1] == 254) ||
|
||||
(ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if len(ip) == net.IPv6len {
|
||||
// IPv6 unique local addresses (fc00::/7)
|
||||
if (ip[0] & 0xfe) == 0xfc {
|
||||
return true
|
||||
}
|
||||
// 6to4 addresses (2002::/16): check the embedded IPv4 at bytes [2:6].
|
||||
if ip[0] == 0x20 && ip[1] == 0x02 {
|
||||
embedded := net.IPv4(ip[2], ip[3], ip[4], ip[5])
|
||||
return isPrivateOrRestrictedIP(embedded)
|
||||
}
|
||||
// Teredo (2001:0000::/32): client IPv4 is at bytes [12:16], XOR-inverted.
|
||||
if ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x00 && ip[3] == 0x00 {
|
||||
client := net.IPv4(ip[12]^0xff, ip[13]^0xff, ip[14]^0xff, ip[15]^0xff)
|
||||
return isPrivateOrRestrictedIP(client)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -18,6 +19,8 @@ const testFetchLimit = int64(10 * 1024 * 1024)
|
||||
|
||||
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
||||
func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -55,6 +58,8 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_JSON verifies JSON content handling
|
||||
func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
testData := map[string]string{"key": "value", "number": "123"}
|
||||
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
|
||||
|
||||
@@ -163,6 +168,8 @@ func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_Truncation verifies content truncation
|
||||
func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
longContent := strings.Repeat("x", 20000)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -205,6 +212,8 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
// Create a mock HTTP server
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
@@ -290,6 +299,8 @@ func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
|
||||
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -404,6 +415,205 @@ func TestWebFetchTool_extractText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func withPrivateWebFetchHostsAllowed(t *testing.T) {
|
||||
t.Helper()
|
||||
previous := allowPrivateWebFetchHosts.Load()
|
||||
allowPrivateWebFetchHosts.Store(true)
|
||||
t.Cleanup(func() {
|
||||
allowPrivateWebFetchHosts.Store(previous)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": "http://127.0.0.1:0",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("expected error for private host URL, got success")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "private or local network") &&
|
||||
!strings.Contains(result.ForUser, "private or local network") {
|
||||
t.Errorf("expected private host block message, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked
|
||||
func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": "http://[::ffff:127.0.0.1]:0",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error for IPv4-mapped IPv6 loopback URL, got success")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked
|
||||
func TestWebFetch_BlocksMetadataIP(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": "http://169.254.169.254/latest/meta-data",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error for cloud metadata IP, got success")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked
|
||||
func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": "http://[fd00::1]:0",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error for IPv6 unique local address, got success")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked
|
||||
func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
// 2002:7f00:0001::1 embeds 127.0.0.1
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": "http://[2002:7f00:0001::1]:0",
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error for 6to4 with private embedded IPv4, got success")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked
|
||||
func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
// 2002:0801:0101::1 embeds 8.1.1.1 (public) — pre-flight should pass,
|
||||
// connection will fail (no listener) but that's after the SSRF check.
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": "http://[2002:0801:0101::1]:0",
|
||||
})
|
||||
|
||||
// Should NOT be blocked by SSRF check — error should be connection failure, not "private"
|
||||
if result.IsError && strings.Contains(result.ForLLM, "private") {
|
||||
t.Error("6to4 with public embedded IPv4 should not be blocked as private")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetch_RedirectToPrivateBlocked verifies redirects to private IPs are blocked
|
||||
func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Redirect to a private IP
|
||||
http.Redirect(w, r, "http://10.0.0.1/secret", http.StatusFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Temporarily disable private host allowance for the redirect check
|
||||
allowPrivateWebFetchHosts.Store(false)
|
||||
defer allowPrivateWebFetchHosts.Store(true)
|
||||
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error when redirecting to private IP, got success")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
|
||||
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
blocked bool
|
||||
desc string
|
||||
}{
|
||||
{"127.0.0.1", true, "IPv4 loopback"},
|
||||
{"10.0.0.1", true, "IPv4 private class A"},
|
||||
{"172.16.0.1", true, "IPv4 private class B"},
|
||||
{"192.168.1.1", true, "IPv4 private class C"},
|
||||
{"169.254.169.254", true, "link-local / cloud metadata"},
|
||||
{"100.64.0.1", true, "carrier-grade NAT"},
|
||||
{"0.0.0.0", true, "unspecified"},
|
||||
{"8.8.8.8", false, "public DNS"},
|
||||
{"1.1.1.1", false, "public DNS"},
|
||||
{"::1", true, "IPv6 loopback"},
|
||||
{"::ffff:127.0.0.1", true, "IPv4-mapped IPv6 loopback"},
|
||||
{"::ffff:10.0.0.1", true, "IPv4-mapped IPv6 private"},
|
||||
{"fc00::1", true, "IPv6 unique local"},
|
||||
{"fd00::1", true, "IPv6 unique local"},
|
||||
{"2002:7f00:0001::1", true, "6to4 with embedded 127.x (private)"},
|
||||
{"2002:0a00:0001::1", true, "6to4 with embedded 10.0.0.1 (private)"},
|
||||
{"2002:0801:0101::1", false, "6to4 with embedded 8.1.1.1 (public)"},
|
||||
{"2001:0000:4136:e378:8000:63bf:f5ff:fffe", true, "Teredo with client 10.0.0.1 (private)"},
|
||||
{"2001:0000:4136:e378:8000:63bf:f7f6:fefe", false, "Teredo with client 8.9.1.1 (public)"},
|
||||
{"2607:f8b0:4004:800::200e", false, "public IPv6 (Google)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP: %s", tt.ip)
|
||||
}
|
||||
got := isPrivateOrRestrictedIP(ip)
|
||||
if got != tt.blocked {
|
||||
t.Errorf("isPrivateOrRestrictedIP(%s) = %v, want %v", tt.ip, got, tt.blocked)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
|
||||
Reference in New Issue
Block a user