fix(security): harden unauthenticated tool-exec paths (#1360)

* fix(security): harden unauthenticated tool-exec paths (GHSA-pv8c-p6jf-3fpp)

- Exec tool: channel-based access control (default deny remote)
- Cron tool: command scheduling restricted to internal channels
- Web fetch: SSRF defense-in-depth (pre-flight + dial-time + redirect checks)
- File permissions: session/state dirs 0700, files 0600
- Registry: inject __channel/__chat_id into tool args (replaces racy SetContext)

28 new security regression tests.

(cherry picked from commit 191446ae19021604d3d5b0d9376b9655ab749105)

* fix(exec): revalidate working_dir before command start

* test(web): allow local oversized payload fixture

---------

Co-authored-by: xj <gh-xj@users.noreply.github.com>
This commit is contained in:
wenjie
2026-03-11 19:22:20 +08:00
committed by GitHub
parent dea06c391c
commit 8c2a9332c6
14 changed files with 622 additions and 30 deletions
+3 -3
View File
@@ -209,7 +209,7 @@ func TestWeComAppVerifySignature(t *testing.T) {
}
})
t.Run("empty token skips verification", func(t *testing.T) {
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
cfgEmpty := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
@@ -218,8 +218,8 @@ func TestWeComAppVerifySignature(t *testing.T) {
}
chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus)
if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should skip verification and return true")
if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should reject verification (fail-closed)")
}
})
}
+3 -4
View File
@@ -189,8 +189,7 @@ func TestWeComBotVerifySignature(t *testing.T) {
}
})
t.Run("empty token skips verification", func(t *testing.T) {
// Create a channel manually with empty token to test the behavior
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
cfgEmpty := config.WeComConfig{
Token: "",
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
@@ -199,8 +198,8 @@ func TestWeComBotVerifySignature(t *testing.T) {
config: cfgEmpty,
}
if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should skip verification and return true")
if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should reject verification (fail-closed)")
}
})
}
+1 -1
View File
@@ -31,7 +31,7 @@ func computeSignature(token, timestamp, nonce, encrypt string) string {
// This is a common function used by both WeCom Bot and WeCom App
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
if token == "" {
return true // Skip verification if token is not set
return false
}
return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
}
+1
View File
@@ -673,6 +673,7 @@ type CronToolsConfig struct {
type ExecConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"`
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
AllowRemote bool ` env:"PICOCLAW_TOOLS_EXEC_ALLOW_REMOTE" json:"allow_remote"`
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
+1
View File
@@ -427,6 +427,7 @@ func DefaultConfig() *Config {
Enabled: true,
},
EnableDenyPatterns: true,
AllowRemote: false,
TimeoutSeconds: 60,
},
Skills: SkillsToolsConfig{
+2 -2
View File
@@ -32,7 +32,7 @@ func NewSessionManager(storage string) *SessionManager {
}
if storage != "" {
os.MkdirAll(storage, 0o755)
os.MkdirAll(storage, 0o700)
sm.loadSessions()
}
@@ -216,7 +216,7 @@ func (sm *SessionManager) Save(key string) error {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Chmod(0o644); err != nil {
if err := tmpFile.Chmod(0o600); err != nil {
_ = tmpFile.Close()
return err
}
+2 -2
View File
@@ -40,8 +40,8 @@ func NewManager(workspace string) *Manager {
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
if err := os.MkdirAll(stateDir, 0o755); err != nil {
log.Fatalf("[FATAL] state: failed to create state directory: %v", err)
if err := os.MkdirAll(stateDir, 0o700); err != nil {
log.Printf("[WARN] state: failed to create state directory %s: %v", stateDir, err)
}
sm := &Manager{
+4 -12
View File
@@ -2,7 +2,6 @@ package state
import (
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
@@ -217,10 +216,7 @@ func TestNewManager_EmptyWorkspace(t *testing.T) {
}
}
func TestNewManager_MkdirFailureCrashes(t *testing.T) {
// Since log.Fatalf calls os.Exit(1), we cannot test it normally
// Otherwise, the test suite would stop altogether.
// We use the standard pattern of Go: rerun this test in a subprocess.
func TestNewManager_MkdirFailureDoesNotCrash(t *testing.T) {
if os.Getenv("BE_CRASHER") == "1" {
tmpDir := os.Getenv("CRASH_DIR")
@@ -240,15 +236,11 @@ func TestNewManager_MkdirFailureCrashes(t *testing.T) {
}
defer os.RemoveAll(tmpDir)
cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureCrashes")
cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureDoesNotCrash")
cmd.Env = append(os.Environ(), "BE_CRASHER=1", "CRASH_DIR="+tmpDir)
err = cmd.Run()
var e *exec.ExitError
if errors.As(err, &e) && !e.Success() {
return
if err != nil {
t.Fatalf("NewManager should not crash when state dir creation fails, got: %v", err)
}
t.Fatalf("The process ended without error, a crash was expected via os.Exit(1). Err: %v", err)
}
+17 -5
View File
@@ -8,6 +8,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -73,6 +74,10 @@ func (t *CronTool) Parameters() map[string]any {
"type": "string",
"description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.",
},
"command_confirm": map[string]any{
"type": "boolean",
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
},
"at_seconds": map[string]any{
"type": "integer",
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
@@ -175,12 +180,17 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
deliver = d
}
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
// Non-command reminders (plain messages) remain open to all channels.
command, _ := args["command"].(string)
commandConfirm, _ := args["command_confirm"].(bool)
if command != "" {
// Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically)
// Actually, let's keep deliver=false to let the system know it's not a simple chat message
// But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set.
// However, logically, it's not "delivered" to chat directly as is.
if !constants.IsInternalChannel(channel) {
return ErrorResult("scheduling command execution is restricted to internal channels")
}
if !commandConfirm {
return ErrorResult("command_confirm=true is required to schedule command execution")
}
deliver = false
}
@@ -281,7 +291,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// Execute command if present
if job.Payload.Command != "" {
args := map[string]any{
"command": job.Payload.Command,
"command": job.Payload.Command,
"__channel": channel,
"__chat_id": chatID,
}
result := t.execTool.Execute(ctx, args)
+116
View File
@@ -0,0 +1,116 @@
package tools
import (
"context"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
)
func newTestCronTool(t *testing.T) *CronTool {
t.Helper()
storePath := filepath.Join(t.TempDir(), "cron.json")
cronService := cron.NewCronService(storePath, nil)
msgBus := bus.NewMessageBus()
cfg := config.DefaultConfig()
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
if err != nil {
t.Fatalf("NewCronTool() error: %v", err)
}
return tool
}
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"command_confirm": true,
"at_seconds": float64(60),
})
if !result.IsError {
t.Fatal("expected command scheduling to be blocked from remote channel")
}
if !strings.Contains(result.ForLLM, "restricted to internal channels") {
t.Errorf("expected 'restricted to internal channels', got: %s", result.ForLLM)
}
}
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"at_seconds": float64(60),
})
if !result.IsError {
t.Fatal("expected error when command_confirm is missing")
}
if !strings.Contains(result.ForLLM, "command_confirm=true") {
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
}
}
// TestCronTool_CommandAllowedFromInternalChannel verifies command scheduling works from internal channels
func TestCronTool_CommandAllowedFromInternalChannel(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"command_confirm": true,
"at_seconds": float64(60),
})
if result.IsError {
t.Fatalf("expected command scheduling to succeed from internal channel, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
}
}
// TestCronTool_AddJobRequiresSessionContext verifies fail-closed when channel/chatID missing
func TestCronTool_AddJobRequiresSessionContext(t *testing.T) {
tool := newTestCronTool(t)
result := tool.Execute(context.Background(), map[string]any{
"action": "add",
"message": "reminder",
"at_seconds": float64(60),
})
if !result.IsError {
t.Fatal("expected error when session context is missing")
}
if !strings.Contains(result.ForLLM, "no session context") {
t.Errorf("expected 'no session context' message, got: %s", result.ForLLM)
}
}
// TestCronTool_NonCommandJobAllowedFromRemoteChannel verifies regular reminders work from any channel
func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "time to stretch",
"at_seconds": float64(600),
})
if result.IsError {
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
}
}
+37
View File
@@ -14,6 +14,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
)
type ExecTool struct {
@@ -23,6 +24,7 @@ type ExecTool struct {
allowPatterns []*regexp.Regexp
customAllowPatterns []*regexp.Regexp
restrictToWorkspace bool
allowRemote bool
}
var (
@@ -100,10 +102,12 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
customAllowPatterns := make([]*regexp.Regexp, 0)
allowRemote := true
if config != nil {
execConfig := config.Tools.Exec
enableDenyPatterns := execConfig.EnableDenyPatterns
allowRemote = execConfig.AllowRemote
if enableDenyPatterns {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
if len(execConfig.CustomDenyPatterns) > 0 {
@@ -143,6 +147,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
restrictToWorkspace: restrict,
allowRemote: allowRemote,
}, nil
}
@@ -177,6 +182,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
return ErrorResult("command is required")
}
// GHSA-pv8c-p6jf-3fpp: block exec from remote channels (e.g. Telegram webhooks)
// unless explicitly opted-in via config. Fail-closed: empty channel = blocked.
if !t.allowRemote {
channel := ToolChannel(ctx)
if channel == "" {
channel, _ = args["__channel"].(string)
}
channel = strings.TrimSpace(channel)
if channel == "" || !constants.IsInternalChannel(channel) {
return ErrorResult("exec is restricted to internal channels")
}
}
cwd := t.workingDir
if wd, ok := args["working_dir"].(string); ok && wd != "" {
if t.restrictToWorkspace && t.workingDir != "" {
@@ -201,6 +219,25 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
return ErrorResult(guardError)
}
// Re-resolve symlinks immediately before execution to shrink the TOCTOU window
// between validation and cmd.Dir assignment.
if t.restrictToWorkspace && t.workingDir != "" && cwd != t.workingDir {
resolved, err := filepath.EvalSymlinks(cwd)
if err != nil {
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
}
absWorkspace, _ := filepath.Abs(t.workingDir)
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
if wsResolved == "" {
wsResolved = absWorkspace
}
rel, err := filepath.Rel(wsResolved, resolved)
if err != nil || !filepath.IsLocal(rel) {
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
}
cwd = resolved
}
// timeout == 0 means no timeout
var cmdCtx context.Context
var cancel context.CancelFunc
+79
View File
@@ -301,6 +301,85 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
}
}
// TestShellTool_RemoteChannelBlockedByDefault verifies exec is blocked for remote channels
func TestShellTool_RemoteChannelBlockedByDefault(t *testing.T) {
cfg := &config.Config{}
cfg.Tools.Exec.EnableDenyPatterns = true
cfg.Tools.Exec.AllowRemote = false
tool, err := NewExecToolWithConfig("", false, cfg)
if err != nil {
t.Fatalf("NewExecToolWithConfig() error: %v", err)
}
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
result := tool.Execute(ctx, map[string]any{"command": "echo hi"})
if !result.IsError {
t.Fatal("expected remote-channel exec to be blocked")
}
if !strings.Contains(result.ForLLM, "restricted to internal channels") {
t.Errorf("expected 'restricted to internal channels' message, got: %s", result.ForLLM)
}
}
// TestShellTool_InternalChannelAllowed verifies exec is allowed for internal channels
func TestShellTool_InternalChannelAllowed(t *testing.T) {
cfg := &config.Config{}
cfg.Tools.Exec.EnableDenyPatterns = true
cfg.Tools.Exec.AllowRemote = false
tool, err := NewExecToolWithConfig("", false, cfg)
if err != nil {
t.Fatalf("NewExecToolWithConfig() error: %v", err)
}
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{"command": "echo hi"})
if result.IsError {
t.Fatalf("expected internal channel exec to succeed, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "hi") {
t.Errorf("expected output to contain 'hi', got: %s", result.ForLLM)
}
}
// TestShellTool_EmptyChannelBlockedWhenNotAllowRemote verifies fail-closed when no channel context
func TestShellTool_EmptyChannelBlockedWhenNotAllowRemote(t *testing.T) {
cfg := &config.Config{}
cfg.Tools.Exec.EnableDenyPatterns = true
cfg.Tools.Exec.AllowRemote = false
tool, err := NewExecToolWithConfig("", false, cfg)
if err != nil {
t.Fatalf("NewExecToolWithConfig() error: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"command": "echo hi",
})
if !result.IsError {
t.Fatal("expected exec with empty channel to be blocked when allowRemote=false")
}
}
// TestShellTool_AllowRemoteBypassesChannelCheck verifies allowRemote=true permits any channel
func TestShellTool_AllowRemoteBypassesChannelCheck(t *testing.T) {
cfg := &config.Config{}
cfg.Tools.Exec.EnableDenyPatterns = true
cfg.Tools.Exec.AllowRemote = true
tool, err := NewExecToolWithConfig("", false, cfg)
if err != nil {
t.Fatalf("NewExecToolWithConfig() error: %v", err)
}
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
result := tool.Execute(ctx, map[string]any{"command": "echo hi"})
if result.IsError {
t.Fatalf("expected allowRemote=true to permit remote channel, got: %s", result.ForLLM)
}
}
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
+146 -1
View File
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"regexp"
@@ -818,6 +819,10 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error)
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
}
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily.
var allowPrivateWebFetchHosts atomic.Bool
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = defaultMaxChars
@@ -826,10 +831,20 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
}
if transport, ok := client.Transport.(*http.Transport); ok {
dialer := &net.Dialer{
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
transport.DialContext = newSafeDialContext(dialer)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
if isObviousPrivateHost(req.URL.Hostname()) {
return fmt.Errorf("redirect target is private or local network host")
}
return nil
}
if fetchLimitBytes <= 0 {
@@ -888,6 +903,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult("missing domain in URL")
}
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
// The real SSRF guard is newSafeDialContext at connect time.
hostname := parsedURL.Hostname()
if isObviousPrivateHost(hostname) {
return ErrorResult("fetching private or local network hosts is not allowed")
}
maxChars := t.maxChars
if mc, ok := args["maxChars"].(float64); ok {
if int(mc) > 100 {
@@ -901,7 +923,6 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
}
req.Header.Set("User-Agent", userAgent)
resp, err := t.client.Do(req)
if err != nil {
return ErrorResult(fmt.Sprintf("request failed: %v", err))
@@ -992,3 +1013,127 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
return strings.Join(cleanLines, "\n")
}
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if allowPrivateWebFetchHosts.Load() {
return dialer.DialContext(ctx, network, address)
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("invalid target address %q: %w", address, err)
}
if host == "" {
return nil, fmt.Errorf("empty target host")
}
if ip := net.ParseIP(host); ip != nil {
if isPrivateOrRestrictedIP(ip) {
return nil, fmt.Errorf("blocked private or local target: %s", host)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
}
ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to resolve %s: %w", host, err)
}
attempted := 0
var lastErr error
for _, ipAddr := range ipAddrs {
if isPrivateOrRestrictedIP(ipAddr.IP) {
continue
}
attempted++
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
if err == nil {
return conn, nil
}
lastErr = err
}
if attempted == 0 {
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
}
if lastErr != nil {
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
}
return nil, fmt.Errorf("failed connecting to public addresses for %s", host)
}
}
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
func isObviousPrivateHost(host string) bool {
if allowPrivateWebFetchHosts.Load() {
return false
}
h := strings.ToLower(strings.TrimSpace(host))
h = strings.TrimSuffix(h, ".")
if h == "" {
return true
}
if h == "localhost" || strings.HasSuffix(h, ".localhost") {
return true
}
if ip := net.ParseIP(h); ip != nil {
return isPrivateOrRestrictedIP(ip)
}
return false
}
// isPrivateOrRestrictedIP returns true for IPs that should never be reached via web_fetch:
// RFC 1918, loopback, link-local (incl. cloud metadata 169.254.x.x), carrier-grade NAT,
// IPv6 unique-local (fc00::/7), 6to4 (2002::/16), and Teredo (2001:0000::/32).
func isPrivateOrRestrictedIP(ip net.IP) bool {
if ip == nil {
return true
}
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
ip.IsMulticast() || ip.IsUnspecified() {
return true
}
if ip4 := ip.To4(); ip4 != nil {
// IPv4 private, loopback, link-local, and carrier-grade NAT ranges.
if ip4[0] == 10 ||
ip4[0] == 127 ||
ip4[0] == 0 ||
(ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) ||
(ip4[0] == 192 && ip4[1] == 168) ||
(ip4[0] == 169 && ip4[1] == 254) ||
(ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127) {
return true
}
return false
}
if len(ip) == net.IPv6len {
// IPv6 unique local addresses (fc00::/7)
if (ip[0] & 0xfe) == 0xfc {
return true
}
// 6to4 addresses (2002::/16): check the embedded IPv4 at bytes [2:6].
if ip[0] == 0x20 && ip[1] == 0x02 {
embedded := net.IPv4(ip[2], ip[3], ip[4], ip[5])
return isPrivateOrRestrictedIP(embedded)
}
// Teredo (2001:0000::/32): client IPv4 is at bytes [12:16], XOR-inverted.
if ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x00 && ip[3] == 0x00 {
client := net.IPv4(ip[12]^0xff, ip[13]^0xff, ip[14]^0xff, ip[15]^0xff)
return isPrivateOrRestrictedIP(client)
}
}
return false
}
+210
View File
@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"strings"
@@ -18,6 +19,8 @@ const testFetchLimit = int64(10 * 1024 * 1024)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
@@ -55,6 +58,8 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
// TestWebTool_WebFetch_JSON verifies JSON content handling
func TestWebTool_WebFetch_JSON(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
testData := map[string]string{"key": "value", "number": "123"}
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
@@ -163,6 +168,8 @@ func TestWebTool_WebFetch_MissingURL(t *testing.T) {
// TestWebTool_WebFetch_Truncation verifies content truncation
func TestWebTool_WebFetch_Truncation(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
longContent := strings.Repeat("x", 20000)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -205,6 +212,8 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
// Create a mock HTTP server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
@@ -290,6 +299,8 @@ func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
@@ -404,6 +415,205 @@ func TestWebFetchTool_extractText(t *testing.T) {
}
}
func withPrivateWebFetchHostsAllowed(t *testing.T) {
t.Helper()
previous := allowPrivateWebFetchHosts.Load()
allowPrivateWebFetchHosts.Store(true)
t.Cleanup(func() {
allowPrivateWebFetchHosts.Store(previous)
})
}
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://127.0.0.1:0",
})
if !result.IsError {
t.Errorf("expected error for private host URL, got success")
}
if !strings.Contains(result.ForLLM, "private or local network") &&
!strings.Contains(result.ForUser, "private or local network") {
t.Errorf("expected private host block message, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM)
}
}
// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked
func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[::ffff:127.0.0.1]:0",
})
if !result.IsError {
t.Error("expected error for IPv4-mapped IPv6 loopback URL, got success")
}
}
// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked
func TestWebFetch_BlocksMetadataIP(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://169.254.169.254/latest/meta-data",
})
if !result.IsError {
t.Error("expected error for cloud metadata IP, got success")
}
}
// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked
func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[fd00::1]:0",
})
if !result.IsError {
t.Error("expected error for IPv6 unique local address, got success")
}
}
// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked
func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
// 2002:7f00:0001::1 embeds 127.0.0.1
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[2002:7f00:0001::1]:0",
})
if !result.IsError {
t.Error("expected error for 6to4 with private embedded IPv4, got success")
}
}
// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked
func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
// 2002:0801:0101::1 embeds 8.1.1.1 (public) — pre-flight should pass,
// connection will fail (no listener) but that's after the SSRF check.
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[2002:0801:0101::1]:0",
})
// Should NOT be blocked by SSRF check — error should be connection failure, not "private"
if result.IsError && strings.Contains(result.ForLLM, "private") {
t.Error("6to4 with public embedded IPv4 should not be blocked as private")
}
}
// TestWebFetch_RedirectToPrivateBlocked verifies redirects to private IPs are blocked
func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Redirect to a private IP
http.Redirect(w, r, "http://10.0.0.1/secret", http.StatusFound)
}))
defer server.Close()
// Temporarily disable private host allowance for the redirect check
allowPrivateWebFetchHosts.Store(false)
defer allowPrivateWebFetchHosts.Store(true)
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if !result.IsError {
t.Error("expected error when redirecting to private IP, got success")
}
}
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
tests := []struct {
ip string
blocked bool
desc string
}{
{"127.0.0.1", true, "IPv4 loopback"},
{"10.0.0.1", true, "IPv4 private class A"},
{"172.16.0.1", true, "IPv4 private class B"},
{"192.168.1.1", true, "IPv4 private class C"},
{"169.254.169.254", true, "link-local / cloud metadata"},
{"100.64.0.1", true, "carrier-grade NAT"},
{"0.0.0.0", true, "unspecified"},
{"8.8.8.8", false, "public DNS"},
{"1.1.1.1", false, "public DNS"},
{"::1", true, "IPv6 loopback"},
{"::ffff:127.0.0.1", true, "IPv4-mapped IPv6 loopback"},
{"::ffff:10.0.0.1", true, "IPv4-mapped IPv6 private"},
{"fc00::1", true, "IPv6 unique local"},
{"fd00::1", true, "IPv6 unique local"},
{"2002:7f00:0001::1", true, "6to4 with embedded 127.x (private)"},
{"2002:0a00:0001::1", true, "6to4 with embedded 10.0.0.1 (private)"},
{"2002:0801:0101::1", false, "6to4 with embedded 8.1.1.1 (public)"},
{"2001:0000:4136:e378:8000:63bf:f5ff:fffe", true, "Teredo with client 10.0.0.1 (private)"},
{"2001:0000:4136:e378:8000:63bf:f7f6:fefe", false, "Teredo with client 8.9.1.1 (public)"},
{"2607:f8b0:4004:800::200e", false, "public IPv6 (Google)"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
got := isPrivateOrRestrictedIP(ip)
if got != tt.blocked {
t.Errorf("isPrivateOrRestrictedIP(%s) = %v, want %v", tt.ip, got, tt.blocked)
}
})
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)