From b4468313e4510c4b68abb61e75be74d284e730ab Mon Sep 17 00:00:00 2001 From: Alix-007 Date: Tue, 17 Mar 2026 23:22:05 +0800 Subject: [PATCH] feat(web): whitelist private fetch targets (#1688) * feat(web): whitelist private fetch targets * test(web): avoid accept error shadowing --------- Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com> --- config/config.example.json | 3 +- pkg/agent/loop.go | 7 +- pkg/config/config.go | 5 +- pkg/tools/web.go | 105 +++++++++++++++++++++++--- pkg/tools/web_test.go | 147 +++++++++++++++++++++++++++++++++++++ 5 files changed, 253 insertions(+), 14 deletions(-) diff --git a/config/config.example.json b/config/config.example.json index 14e209259..f05a09ef9 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -351,7 +351,8 @@ "search_engine": "search_std", "max_results": 5 }, - "fetch_limit_bytes": 10485760 + "fetch_limit_bytes": 10485760, + "private_host_whitelist": [] }, "cron": { "enabled": true, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8328c691e..c25650201 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -159,7 +159,12 @@ func registerSharedTools( } } if cfg.Tools.IsToolEnabled("web_fetch") { - fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes) + fetchTool, err := tools.NewWebFetchToolWithConfig( + 50000, + cfg.Tools.Web.Proxy, + cfg.Tools.Web.FetchLimitBytes, + cfg.Tools.Web.PrivateHostWhitelist, + ) if err != nil { logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) } else { diff --git a/pkg/config/config.go b/pkg/config/config.go index 6694ef3a1..005e44a30 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -695,8 +695,9 @@ type WebToolsConfig struct { GLMSearch GLMSearchConfig ` json:"glm_search"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. - Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` - FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` + FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"` + PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"` } type CronToolsConfig struct { diff --git a/pkg/tools/web.go b/pkg/tools/web.go index e5036d3a8..9ed2140cc 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -777,11 +777,17 @@ type WebFetchTool struct { proxy string client *http.Client fetchLimitBytes int64 + whitelist *privateHostWhitelist +} + +type privateHostWhitelist struct { + exact map[string]struct{} + cidrs []*net.IPNet } func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) { // createHTTPClient cannot fail with an empty proxy string. - return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes) + return NewWebFetchToolWithConfig(maxChars, "", fetchLimitBytes, nil) } // allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed. @@ -789,9 +795,22 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) var allowPrivateWebFetchHosts atomic.Bool func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) { + return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil) +} + +func NewWebFetchToolWithConfig( + maxChars int, + proxy string, + fetchLimitBytes int64, + privateHostWhitelist []string, +) (*WebFetchTool, error) { if maxChars <= 0 { maxChars = defaultMaxChars } + whitelist, err := newPrivateHostWhitelist(privateHostWhitelist) + if err != nil { + return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err) + } client, err := utils.CreateHTTPClient(proxy, fetchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err) @@ -801,13 +820,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) Timeout: 15 * time.Second, KeepAlive: 30 * time.Second, } - transport.DialContext = newSafeDialContext(dialer) + transport.DialContext = newSafeDialContext(dialer, whitelist) } client.CheckRedirect = func(req *http.Request, via []*http.Request) error { if len(via) >= maxRedirects { return fmt.Errorf("stopped after %d redirects", maxRedirects) } - if isObviousPrivateHost(req.URL.Hostname()) { + if isObviousPrivateHost(req.URL.Hostname(), whitelist) { return fmt.Errorf("redirect target is private or local network host") } return nil @@ -820,6 +839,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) proxy: proxy, client: client, fetchLimitBytes: fetchLimitBytes, + whitelist: whitelist, }, nil } @@ -871,7 +891,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe // Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution. // The real SSRF guard is newSafeDialContext at connect time. hostname := parsedURL.Hostname() - if isObviousPrivateHost(hostname) { + if isObviousPrivateHost(hostname, t.whitelist) { return ErrorResult("fetching private or local network hosts is not allowed") } @@ -981,7 +1001,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string { // newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU) // where a hostname resolves to a public IP during pre-flight but a private IP at connect time. -func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { +func newSafeDialContext( + dialer *net.Dialer, + whitelist *privateHostWhitelist, +) func(context.Context, string, string) (net.Conn, error) { return func(ctx context.Context, network, address string) (net.Conn, error) { if allowPrivateWebFetchHosts.Load() { return dialer.DialContext(ctx, network, address) @@ -996,7 +1019,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string } if ip := net.ParseIP(host); ip != nil { - if isPrivateOrRestrictedIP(ip) { + if shouldBlockPrivateIP(ip, whitelist) { return nil, fmt.Errorf("blocked private or local target: %s", host) } return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) @@ -1010,7 +1033,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string attempted := 0 var lastErr error for _, ipAddr := range ipAddrs { - if isPrivateOrRestrictedIP(ipAddr.IP) { + if shouldBlockPrivateIP(ipAddr.IP, whitelist) { continue } attempted++ @@ -1022,7 +1045,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string } if attempted == 0 { - return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host) + return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host) } if lastErr != nil { return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr) @@ -1031,10 +1054,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string } } +func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) { + if len(entries) == 0 { + return nil, nil + } + + whitelist := &privateHostWhitelist{ + exact: make(map[string]struct{}), + cidrs: make([]*net.IPNet, 0, len(entries)), + } + for _, entry := range entries { + entry = strings.TrimSpace(entry) + if entry == "" { + continue + } + if ip := net.ParseIP(entry); ip != nil { + whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{} + continue + } + _, network, err := net.ParseCIDR(entry) + if err != nil { + return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry) + } + whitelist.cidrs = append(whitelist.cidrs, network) + } + + if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 { + return nil, nil + } + return whitelist, nil +} + +func (w *privateHostWhitelist) Contains(ip net.IP) bool { + if w == nil || ip == nil { + return false + } + + normalized := normalizeWhitelistIP(ip) + if _, ok := w.exact[normalized.String()]; ok { + return true + } + for _, network := range w.cidrs { + if network.Contains(normalized) { + return true + } + } + return false +} + +func normalizeWhitelistIP(ip net.IP) net.IP { + if ip == nil { + return nil + } + if ip4 := ip.To4(); ip4 != nil { + return ip4 + } + return ip +} + +func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool { + return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip) +} + // isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts. // It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS — // the real SSRF guard is newSafeDialContext which checks IPs at connect time. -func isObviousPrivateHost(host string) bool { +func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool { if allowPrivateWebFetchHosts.Load() { return false } @@ -1050,7 +1135,7 @@ func isObviousPrivateHost(host string) bool { } if ip := net.ParseIP(h); ip != nil { - return isPrivateOrRestrictedIP(ip) + return shouldBlockPrivateIP(ip, whitelist) } return false diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 41d83e6f5..80c9a2067 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -10,6 +10,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -423,6 +424,29 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) { }) } +func serverHostAndPort(t *testing.T, rawURL string) (string, string) { + t.Helper() + hostPort := strings.TrimPrefix(rawURL, "http://") + hostPort = strings.TrimPrefix(hostPort, "https://") + host, port, err := net.SplitHostPort(hostPort) + if err != nil { + t.Fatalf("failed to split host/port from %q: %v", rawURL, err) + } + return host, port +} + +func singleHostCIDR(t *testing.T, host string) string { + t.Helper() + ip := net.ParseIP(host) + if ip == nil { + t.Fatalf("failed to parse IP %q", host) + } + if ip.To4() != nil { + return ip.String() + "/32" + } + return ip.String() + "/128" +} + func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) { tool, err := NewWebFetchTool(50000, testFetchLimit) if err != nil { @@ -441,6 +465,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) { } } +func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("exact whitelist ok")) + })) + defer server.Close() + + host, _ := serverHostAndPort(t, server.URL) + tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{host}) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "url": server.URL, + }) + if result.IsError { + t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "exact whitelist ok") { + t.Fatalf("expected fetched content, got %q", result.ForLLM) + } +} + +func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("cidr whitelist ok")) + })) + defer server.Close() + + host, _ := serverHostAndPort(t, server.URL) + tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{singleHostCIDR(t, host)}) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "url": server.URL, + }) + if result.IsError { + t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "cidr whitelist ok") { + t.Fatalf("expected fetched content, got %q", result.ForLLM) + } +} + func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) { withPrivateWebFetchHostsAllowed(t) @@ -570,6 +644,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) { } } +func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on loopback: %v", err) + } + defer listener.Close() + + _, port, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatalf("failed to split listener address: %v", err) + } + + dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil) + _, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port)) + if err == nil { + t.Fatal("expected localhost DNS resolution to be blocked without whitelist") + } + if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on loopback: %v", err) + } + defer listener.Close() + + accepted := make(chan struct{}, 1) + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + conn.Close() + accepted <- struct{}{} + }() + + _, port, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatalf("failed to split listener address: %v", err) + } + + whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"}) + if err != nil { + t.Fatalf("failed to parse whitelist: %v", err) + } + + dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist) + conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port)) + if err != nil { + t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err) + } + conn.Close() + + select { + case <-accepted: + case <-time.After(time.Second): + t.Fatal("expected localhost listener to accept a connection") + } +} + // TestIsPrivateOrRestrictedIP_Table tests IP classification logic func TestIsPrivateOrRestrictedIP_Table(t *testing.T) { tests := []struct { @@ -660,6 +797,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) { } } +func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) { + _, err := NewWebFetchToolWithConfig(1024, "", testFetchLimit, []string{"not-an-ip-or-cidr"}) + if err == nil { + t.Fatal("expected invalid whitelist entry to fail") + } + if !strings.Contains(err.Error(), "invalid entry") { + t.Fatalf("unexpected error: %v", err) + } +} + func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { t.Run("perplexity", func(t *testing.T) { tool, err := NewWebSearchTool(WebSearchToolOptions{