diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go index 7f230494f..7d07041ad 100644 --- a/pkg/channels/wecom/app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -209,7 +209,7 @@ func TestWeComAppVerifySignature(t *testing.T) { } }) - t.Run("empty token skips verification", func(t *testing.T) { + t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { cfgEmpty := config.WeComAppConfig{ CorpID: "test_corp_id", CorpSecret: "test_secret", @@ -218,8 +218,8 @@ func TestWeComAppVerifySignature(t *testing.T) { } chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should skip verification and return true") + if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should reject verification (fail-closed)") } }) } diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go index c053578b1..d223bb6b6 100644 --- a/pkg/channels/wecom/bot_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -189,8 +189,7 @@ func TestWeComBotVerifySignature(t *testing.T) { } }) - t.Run("empty token skips verification", func(t *testing.T) { - // Create a channel manually with empty token to test the behavior + t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { cfgEmpty := config.WeComConfig{ Token: "", WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", @@ -199,8 +198,8 @@ func TestWeComBotVerifySignature(t *testing.T) { config: cfgEmpty, } - if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should skip verification and return true") + if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should reject verification (fail-closed)") } }) } diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go index 6510e6f81..9a622a2fc 100644 --- a/pkg/channels/wecom/common.go +++ b/pkg/channels/wecom/common.go @@ -31,7 +31,7 @@ func computeSignature(token, timestamp, nonce, encrypt string) string { // This is a common function used by both WeCom Bot and WeCom App func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { if token == "" { - return true // Skip verification if token is not set + return false } return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature } diff --git a/pkg/config/config.go b/pkg/config/config.go index 7a806c1e1..161290c64 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -673,6 +673,7 @@ type CronToolsConfig struct { type ExecConfig struct { ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"` EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"` + AllowRemote bool ` env:"PICOCLAW_TOOLS_EXEC_ALLOW_REMOTE" json:"allow_remote"` CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"` CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"` TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s) diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 3b1bb1aef..56cc95375 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -427,6 +427,7 @@ func DefaultConfig() *Config { Enabled: true, }, EnableDenyPatterns: true, + AllowRemote: false, TimeoutSeconds: 60, }, Skills: SkillsToolsConfig{ diff --git a/pkg/session/manager.go b/pkg/session/manager.go index a31dbd55c..ef720b7c5 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -32,7 +32,7 @@ func NewSessionManager(storage string) *SessionManager { } if storage != "" { - os.MkdirAll(storage, 0o755) + os.MkdirAll(storage, 0o700) sm.loadSessions() } @@ -216,7 +216,7 @@ func (sm *SessionManager) Save(key string) error { _ = tmpFile.Close() return err } - if err := tmpFile.Chmod(0o644); err != nil { + if err := tmpFile.Chmod(0o600); err != nil { _ = tmpFile.Close() return err } diff --git a/pkg/state/state.go b/pkg/state/state.go index 57f371f12..5da7bbde1 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -40,8 +40,8 @@ func NewManager(workspace string) *Manager { oldStateFile := filepath.Join(workspace, "state.json") // Create state directory if it doesn't exist - if err := os.MkdirAll(stateDir, 0o755); err != nil { - log.Fatalf("[FATAL] state: failed to create state directory: %v", err) + if err := os.MkdirAll(stateDir, 0o700); err != nil { + log.Printf("[WARN] state: failed to create state directory %s: %v", stateDir, err) } sm := &Manager{ diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index e5e116ef6..3924e5533 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -2,7 +2,6 @@ package state import ( "encoding/json" - "errors" "fmt" "os" "os/exec" @@ -217,10 +216,7 @@ func TestNewManager_EmptyWorkspace(t *testing.T) { } } -func TestNewManager_MkdirFailureCrashes(t *testing.T) { - // Since log.Fatalf calls os.Exit(1), we cannot test it normally - // Otherwise, the test suite would stop altogether. - // We use the standard pattern of Go: rerun this test in a subprocess. +func TestNewManager_MkdirFailureDoesNotCrash(t *testing.T) { if os.Getenv("BE_CRASHER") == "1" { tmpDir := os.Getenv("CRASH_DIR") @@ -240,15 +236,11 @@ func TestNewManager_MkdirFailureCrashes(t *testing.T) { } defer os.RemoveAll(tmpDir) - cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureCrashes") + cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureDoesNotCrash") cmd.Env = append(os.Environ(), "BE_CRASHER=1", "CRASH_DIR="+tmpDir) err = cmd.Run() - - var e *exec.ExitError - if errors.As(err, &e) && !e.Success() { - return + if err != nil { + t.Fatalf("NewManager should not crash when state dir creation fails, got: %v", err) } - - t.Fatalf("The process ended without error, a crash was expected via os.Exit(1). Err: %v", err) } diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 6af0aa9e1..648cc3c6c 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -8,6 +8,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -73,6 +74,10 @@ func (t *CronTool) Parameters() map[string]any { "type": "string", "description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.", }, + "command_confirm": map[string]any{ + "type": "boolean", + "description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.", + }, "at_seconds": map[string]any{ "type": "integer", "description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.", @@ -175,12 +180,17 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult deliver = d } + // GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm. + // Non-command reminders (plain messages) remain open to all channels. command, _ := args["command"].(string) + commandConfirm, _ := args["command_confirm"].(bool) if command != "" { - // Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically) - // Actually, let's keep deliver=false to let the system know it's not a simple chat message - // But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set. - // However, logically, it's not "delivered" to chat directly as is. + if !constants.IsInternalChannel(channel) { + return ErrorResult("scheduling command execution is restricted to internal channels") + } + if !commandConfirm { + return ErrorResult("command_confirm=true is required to schedule command execution") + } deliver = false } @@ -281,7 +291,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // Execute command if present if job.Payload.Command != "" { args := map[string]any{ - "command": job.Payload.Command, + "command": job.Payload.Command, + "__channel": channel, + "__chat_id": chatID, } result := t.execTool.Execute(ctx, args) diff --git a/pkg/tools/cron_test.go b/pkg/tools/cron_test.go new file mode 100644 index 000000000..1776abc65 --- /dev/null +++ b/pkg/tools/cron_test.go @@ -0,0 +1,116 @@ +package tools + +import ( + "context" + "path/filepath" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/cron" +) + +func newTestCronTool(t *testing.T) *CronTool { + t.Helper() + storePath := filepath.Join(t.TempDir(), "cron.json") + cronService := cron.NewCronService(storePath, nil) + msgBus := bus.NewMessageBus() + cfg := config.DefaultConfig() + tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg) + if err != nil { + t.Fatalf("NewCronTool() error: %v", err) + } + return tool +} + +// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels +func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "check disk", + "command": "df -h", + "command_confirm": true, + "at_seconds": float64(60), + }) + + if !result.IsError { + t.Fatal("expected command scheduling to be blocked from remote channel") + } + if !strings.Contains(result.ForLLM, "restricted to internal channels") { + t.Errorf("expected 'restricted to internal channels', got: %s", result.ForLLM) + } +} + +// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required +func TestCronTool_CommandRequiresConfirm(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "cli", "direct") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "check disk", + "command": "df -h", + "at_seconds": float64(60), + }) + + if !result.IsError { + t.Fatal("expected error when command_confirm is missing") + } + if !strings.Contains(result.ForLLM, "command_confirm=true") { + t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM) + } +} + +// TestCronTool_CommandAllowedFromInternalChannel verifies command scheduling works from internal channels +func TestCronTool_CommandAllowedFromInternalChannel(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "cli", "direct") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "check disk", + "command": "df -h", + "command_confirm": true, + "at_seconds": float64(60), + }) + + if result.IsError { + t.Fatalf("expected command scheduling to succeed from internal channel, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Cron job added") { + t.Errorf("expected 'Cron job added', got: %s", result.ForLLM) + } +} + +// TestCronTool_AddJobRequiresSessionContext verifies fail-closed when channel/chatID missing +func TestCronTool_AddJobRequiresSessionContext(t *testing.T) { + tool := newTestCronTool(t) + result := tool.Execute(context.Background(), map[string]any{ + "action": "add", + "message": "reminder", + "at_seconds": float64(60), + }) + + if !result.IsError { + t.Fatal("expected error when session context is missing") + } + if !strings.Contains(result.ForLLM, "no session context") { + t.Errorf("expected 'no session context' message, got: %s", result.ForLLM) + } +} + +// TestCronTool_NonCommandJobAllowedFromRemoteChannel verifies regular reminders work from any channel +func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "time to stretch", + "at_seconds": float64(600), + }) + + if result.IsError { + t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index b8a811d03..67e2ad257 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -14,6 +14,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" ) type ExecTool struct { @@ -23,6 +24,7 @@ type ExecTool struct { allowPatterns []*regexp.Regexp customAllowPatterns []*regexp.Regexp restrictToWorkspace bool + allowRemote bool } var ( @@ -100,10 +102,12 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) { func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) { denyPatterns := make([]*regexp.Regexp, 0) customAllowPatterns := make([]*regexp.Regexp, 0) + allowRemote := true if config != nil { execConfig := config.Tools.Exec enableDenyPatterns := execConfig.EnableDenyPatterns + allowRemote = execConfig.AllowRemote if enableDenyPatterns { denyPatterns = append(denyPatterns, defaultDenyPatterns...) if len(execConfig.CustomDenyPatterns) > 0 { @@ -143,6 +147,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf allowPatterns: nil, customAllowPatterns: customAllowPatterns, restrictToWorkspace: restrict, + allowRemote: allowRemote, }, nil } @@ -177,6 +182,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult return ErrorResult("command is required") } + // GHSA-pv8c-p6jf-3fpp: block exec from remote channels (e.g. Telegram webhooks) + // unless explicitly opted-in via config. Fail-closed: empty channel = blocked. + if !t.allowRemote { + channel := ToolChannel(ctx) + if channel == "" { + channel, _ = args["__channel"].(string) + } + channel = strings.TrimSpace(channel) + if channel == "" || !constants.IsInternalChannel(channel) { + return ErrorResult("exec is restricted to internal channels") + } + } + cwd := t.workingDir if wd, ok := args["working_dir"].(string); ok && wd != "" { if t.restrictToWorkspace && t.workingDir != "" { @@ -201,6 +219,25 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult return ErrorResult(guardError) } + // Re-resolve symlinks immediately before execution to shrink the TOCTOU window + // between validation and cmd.Dir assignment. + if t.restrictToWorkspace && t.workingDir != "" && cwd != t.workingDir { + resolved, err := filepath.EvalSymlinks(cwd) + if err != nil { + return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err)) + } + absWorkspace, _ := filepath.Abs(t.workingDir) + wsResolved, _ := filepath.EvalSymlinks(absWorkspace) + if wsResolved == "" { + wsResolved = absWorkspace + } + rel, err := filepath.Rel(wsResolved, resolved) + if err != nil || !filepath.IsLocal(rel) { + return ErrorResult("Command blocked by safety guard (working directory escaped workspace)") + } + cwd = resolved + } + // timeout == 0 means no timeout var cmdCtx context.Context var cancel context.CancelFunc diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index ff9ea4a15..90265e5bd 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -301,6 +301,85 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) { } } +// TestShellTool_RemoteChannelBlockedByDefault verifies exec is blocked for remote channels +func TestShellTool_RemoteChannelBlockedByDefault(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = false + + tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + + if !result.IsError { + t.Fatal("expected remote-channel exec to be blocked") + } + if !strings.Contains(result.ForLLM, "restricted to internal channels") { + t.Errorf("expected 'restricted to internal channels' message, got: %s", result.ForLLM) + } +} + +// TestShellTool_InternalChannelAllowed verifies exec is allowed for internal channels +func TestShellTool_InternalChannelAllowed(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = false + + tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + ctx := WithToolContext(context.Background(), "cli", "direct") + result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + + if result.IsError { + t.Fatalf("expected internal channel exec to succeed, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "hi") { + t.Errorf("expected output to contain 'hi', got: %s", result.ForLLM) + } +} + +// TestShellTool_EmptyChannelBlockedWhenNotAllowRemote verifies fail-closed when no channel context +func TestShellTool_EmptyChannelBlockedWhenNotAllowRemote(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = false + + tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "command": "echo hi", + }) + + if !result.IsError { + t.Fatal("expected exec with empty channel to be blocked when allowRemote=false") + } +} + +// TestShellTool_AllowRemoteBypassesChannelCheck verifies allowRemote=true permits any channel +func TestShellTool_AllowRemoteBypassesChannelCheck(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = true + + tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + + if result.IsError { + t.Fatalf("expected allowRemote=true to permit remote channel, got: %s", result.ForLLM) + } +} + // TestShellTool_RestrictToWorkspace verifies workspace restriction func TestShellTool_RestrictToWorkspace(t *testing.T) { tmpDir := t.TempDir() diff --git a/pkg/tools/web.go b/pkg/tools/web.go index e248ea966..003cd860c 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/url" "regexp" @@ -818,6 +819,10 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes) } +// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed. +// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily. +var allowPrivateWebFetchHosts atomic.Bool + func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) { if maxChars <= 0 { maxChars = defaultMaxChars @@ -826,10 +831,20 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err) } + if transport, ok := client.Transport.(*http.Transport); ok { + dialer := &net.Dialer{ + Timeout: 15 * time.Second, + KeepAlive: 30 * time.Second, + } + transport.DialContext = newSafeDialContext(dialer) + } client.CheckRedirect = func(req *http.Request, via []*http.Request) error { if len(via) >= maxRedirects { return fmt.Errorf("stopped after %d redirects", maxRedirects) } + if isObviousPrivateHost(req.URL.Hostname()) { + return fmt.Errorf("redirect target is private or local network host") + } return nil } if fetchLimitBytes <= 0 { @@ -888,6 +903,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult("missing domain in URL") } + // Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution. + // The real SSRF guard is newSafeDialContext at connect time. + hostname := parsedURL.Hostname() + if isObviousPrivateHost(hostname) { + return ErrorResult("fetching private or local network hosts is not allowed") + } + maxChars := t.maxChars if mc, ok := args["maxChars"].(float64); ok { if int(mc) > 100 { @@ -901,7 +923,6 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe } req.Header.Set("User-Agent", userAgent) - resp, err := t.client.Do(req) if err != nil { return ErrorResult(fmt.Sprintf("request failed: %v", err)) @@ -992,3 +1013,127 @@ func (t *WebFetchTool) extractText(htmlContent string) string { return strings.Join(cleanLines, "\n") } + +// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU) +// where a hostname resolves to a public IP during pre-flight but a private IP at connect time. +func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + if allowPrivateWebFetchHosts.Load() { + return dialer.DialContext(ctx, network, address) + } + + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("invalid target address %q: %w", address, err) + } + if host == "" { + return nil, fmt.Errorf("empty target host") + } + + if ip := net.ParseIP(host); ip != nil { + if isPrivateOrRestrictedIP(ip) { + return nil, fmt.Errorf("blocked private or local target: %s", host) + } + return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + } + + ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve %s: %w", host, err) + } + + attempted := 0 + var lastErr error + for _, ipAddr := range ipAddrs { + if isPrivateOrRestrictedIP(ipAddr.IP) { + continue + } + attempted++ + conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port)) + if err == nil { + return conn, nil + } + lastErr = err + } + + if attempted == 0 { + return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host) + } + if lastErr != nil { + return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr) + } + return nil, fmt.Errorf("failed connecting to public addresses for %s", host) + } +} + +// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts. +// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS — +// the real SSRF guard is newSafeDialContext which checks IPs at connect time. +func isObviousPrivateHost(host string) bool { + if allowPrivateWebFetchHosts.Load() { + return false + } + + h := strings.ToLower(strings.TrimSpace(host)) + h = strings.TrimSuffix(h, ".") + if h == "" { + return true + } + + if h == "localhost" || strings.HasSuffix(h, ".localhost") { + return true + } + + if ip := net.ParseIP(h); ip != nil { + return isPrivateOrRestrictedIP(ip) + } + + return false +} + +// isPrivateOrRestrictedIP returns true for IPs that should never be reached via web_fetch: +// RFC 1918, loopback, link-local (incl. cloud metadata 169.254.x.x), carrier-grade NAT, +// IPv6 unique-local (fc00::/7), 6to4 (2002::/16), and Teredo (2001:0000::/32). +func isPrivateOrRestrictedIP(ip net.IP) bool { + if ip == nil { + return true + } + + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || + ip.IsMulticast() || ip.IsUnspecified() { + return true + } + + if ip4 := ip.To4(); ip4 != nil { + // IPv4 private, loopback, link-local, and carrier-grade NAT ranges. + if ip4[0] == 10 || + ip4[0] == 127 || + ip4[0] == 0 || + (ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) || + (ip4[0] == 192 && ip4[1] == 168) || + (ip4[0] == 169 && ip4[1] == 254) || + (ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127) { + return true + } + return false + } + + if len(ip) == net.IPv6len { + // IPv6 unique local addresses (fc00::/7) + if (ip[0] & 0xfe) == 0xfc { + return true + } + // 6to4 addresses (2002::/16): check the embedded IPv4 at bytes [2:6]. + if ip[0] == 0x20 && ip[1] == 0x02 { + embedded := net.IPv4(ip[2], ip[3], ip[4], ip[5]) + return isPrivateOrRestrictedIP(embedded) + } + // Teredo (2001:0000::/32): client IPv4 is at bytes [12:16], XOR-inverted. + if ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x00 && ip[3] == 0x00 { + client := net.IPv4(ip[12]^0xff, ip[13]^0xff, ip[14]^0xff, ip[15]^0xff) + return isPrivateOrRestrictedIP(client) + } + } + + return false +} diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 188fb8adb..0737d2087 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "net/http" "net/http/httptest" "strings" @@ -18,6 +19,8 @@ const testFetchLimit = int64(10 * 1024 * 1024) // TestWebTool_WebFetch_Success verifies successful URL fetching func TestWebTool_WebFetch_Success(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) @@ -55,6 +58,8 @@ func TestWebTool_WebFetch_Success(t *testing.T) { // TestWebTool_WebFetch_JSON verifies JSON content handling func TestWebTool_WebFetch_JSON(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + testData := map[string]string{"key": "value", "number": "123"} expectedJSON, _ := json.MarshalIndent(testData, "", " ") @@ -163,6 +168,8 @@ func TestWebTool_WebFetch_MissingURL(t *testing.T) { // TestWebTool_WebFetch_Truncation verifies content truncation func TestWebTool_WebFetch_Truncation(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + longContent := strings.Repeat("x", 20000) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -205,6 +212,8 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { } func TestWebFetchTool_PayloadTooLarge(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + // Create a mock HTTP server ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") @@ -290,6 +299,8 @@ func TestWebTool_WebSearch_MissingQuery(t *testing.T) { // TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) @@ -404,6 +415,205 @@ func TestWebFetchTool_extractText(t *testing.T) { } } +func withPrivateWebFetchHostsAllowed(t *testing.T) { + t.Helper() + previous := allowPrivateWebFetchHosts.Load() + allowPrivateWebFetchHosts.Store(true) + t.Cleanup(func() { + allowPrivateWebFetchHosts.Store(previous) + }) +} + +func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://127.0.0.1:0", + }) + + if !result.IsError { + t.Errorf("expected error for private host URL, got success") + } + if !strings.Contains(result.ForLLM, "private or local network") && + !strings.Contains(result.ForUser, "private or local network") { + t.Errorf("expected private host block message, got %q", result.ForLLM) + } +} + +func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + })) + defer server.Close() + + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": server.URL, + }) + + if result.IsError { + t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM) + } +} + +// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked +func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[::ffff:127.0.0.1]:0", + }) + + if !result.IsError { + t.Error("expected error for IPv4-mapped IPv6 loopback URL, got success") + } +} + +// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked +func TestWebFetch_BlocksMetadataIP(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://169.254.169.254/latest/meta-data", + }) + + if !result.IsError { + t.Error("expected error for cloud metadata IP, got success") + } +} + +// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked +func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[fd00::1]:0", + }) + + if !result.IsError { + t.Error("expected error for IPv6 unique local address, got success") + } +} + +// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked +func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + // 2002:7f00:0001::1 embeds 127.0.0.1 + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[2002:7f00:0001::1]:0", + }) + + if !result.IsError { + t.Error("expected error for 6to4 with private embedded IPv4, got success") + } +} + +// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked +func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + // 2002:0801:0101::1 embeds 8.1.1.1 (public) — pre-flight should pass, + // connection will fail (no listener) but that's after the SSRF check. + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[2002:0801:0101::1]:0", + }) + + // Should NOT be blocked by SSRF check — error should be connection failure, not "private" + if result.IsError && strings.Contains(result.ForLLM, "private") { + t.Error("6to4 with public embedded IPv4 should not be blocked as private") + } +} + +// TestWebFetch_RedirectToPrivateBlocked verifies redirects to private IPs are blocked +func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Redirect to a private IP + http.Redirect(w, r, "http://10.0.0.1/secret", http.StatusFound) + })) + defer server.Close() + + // Temporarily disable private host allowance for the redirect check + allowPrivateWebFetchHosts.Store(false) + defer allowPrivateWebFetchHosts.Store(true) + + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": server.URL, + }) + + if !result.IsError { + t.Error("expected error when redirecting to private IP, got success") + } +} + +// TestIsPrivateOrRestrictedIP_Table tests IP classification logic +func TestIsPrivateOrRestrictedIP_Table(t *testing.T) { + tests := []struct { + ip string + blocked bool + desc string + }{ + {"127.0.0.1", true, "IPv4 loopback"}, + {"10.0.0.1", true, "IPv4 private class A"}, + {"172.16.0.1", true, "IPv4 private class B"}, + {"192.168.1.1", true, "IPv4 private class C"}, + {"169.254.169.254", true, "link-local / cloud metadata"}, + {"100.64.0.1", true, "carrier-grade NAT"}, + {"0.0.0.0", true, "unspecified"}, + {"8.8.8.8", false, "public DNS"}, + {"1.1.1.1", false, "public DNS"}, + {"::1", true, "IPv6 loopback"}, + {"::ffff:127.0.0.1", true, "IPv4-mapped IPv6 loopback"}, + {"::ffff:10.0.0.1", true, "IPv4-mapped IPv6 private"}, + {"fc00::1", true, "IPv6 unique local"}, + {"fd00::1", true, "IPv6 unique local"}, + {"2002:7f00:0001::1", true, "6to4 with embedded 127.x (private)"}, + {"2002:0a00:0001::1", true, "6to4 with embedded 10.0.0.1 (private)"}, + {"2002:0801:0101::1", false, "6to4 with embedded 8.1.1.1 (public)"}, + {"2001:0000:4136:e378:8000:63bf:f5ff:fffe", true, "Teredo with client 10.0.0.1 (private)"}, + {"2001:0000:4136:e378:8000:63bf:f7f6:fefe", false, "Teredo with client 8.9.1.1 (public)"}, + {"2607:f8b0:4004:800::200e", false, "public IPv6 (Google)"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("failed to parse IP: %s", tt.ip) + } + got := isPrivateOrRestrictedIP(ip) + if got != tt.blocked { + t.Errorf("isPrivateOrRestrictedIP(%s) = %v, want %v", tt.ip, got, tt.blocked) + } + }) + } +} + // TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain func TestWebTool_WebFetch_MissingDomain(t *testing.T) { tool, err := NewWebFetchTool(50000, testFetchLimit)