feat(web): whitelist private fetch targets (#1688)

* feat(web): whitelist private fetch targets

* test(web): avoid accept error shadowing

---------

Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com>
This commit is contained in:
Alix-007
2026-03-17 23:22:05 +08:00
committed by GitHub
parent 5bc4fe4dea
commit b4468313e4
5 changed files with 253 additions and 14 deletions
+2 -1
View File
@@ -351,7 +351,8 @@
"search_engine": "search_std",
"max_results": 5
},
"fetch_limit_bytes": 10485760
"fetch_limit_bytes": 10485760,
"private_host_whitelist": []
},
"cron": {
"enabled": true,
+6 -1
View File
@@ -159,7 +159,12 @@ func registerSharedTools(
}
}
if cfg.Tools.IsToolEnabled("web_fetch") {
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
fetchTool, err := tools.NewWebFetchToolWithConfig(
50000,
cfg.Tools.Web.Proxy,
cfg.Tools.Web.FetchLimitBytes,
cfg.Tools.Web.PrivateHostWhitelist,
)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
+3 -2
View File
@@ -695,8 +695,9 @@ type WebToolsConfig struct {
GLMSearch GLMSearchConfig ` json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
}
type CronToolsConfig struct {
+95 -10
View File
@@ -777,11 +777,17 @@ type WebFetchTool struct {
proxy string
client *http.Client
fetchLimitBytes int64
whitelist *privateHostWhitelist
}
// privateHostWhitelist is the parsed form of the configured private host
// whitelist: addresses that would otherwise be rejected as private or
// restricted but are explicitly allowed for web fetch. A nil pointer is a
// valid, empty whitelist.
type privateHostWhitelist struct {
// exact is a set of individually listed IPs, keyed by the string form of the
// normalized address.
exact map[string]struct{}
// cidrs holds the networks parsed from CIDR-style entries.
cidrs []*net.IPNet
}
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
// createHTTPClient cannot fail with an empty proxy string.
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
return NewWebFetchToolWithConfig(maxChars, "", fetchLimitBytes, nil)
}
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
@@ -789,9 +795,22 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error)
var allowPrivateWebFetchHosts atomic.Bool
// NewWebFetchToolWithProxy creates a WebFetchTool that uses the given proxy and
// has no private host whitelist. It delegates to NewWebFetchToolWithConfig with
// a nil whitelist.
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil)
}
func NewWebFetchToolWithConfig(
maxChars int,
proxy string,
fetchLimitBytes int64,
privateHostWhitelist []string,
) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = defaultMaxChars
}
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
if err != nil {
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
}
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
@@ -801,13 +820,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
transport.DialContext = newSafeDialContext(dialer)
transport.DialContext = newSafeDialContext(dialer, whitelist)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
if isObviousPrivateHost(req.URL.Hostname()) {
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
return fmt.Errorf("redirect target is private or local network host")
}
return nil
@@ -820,6 +839,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
proxy: proxy,
client: client,
fetchLimitBytes: fetchLimitBytes,
whitelist: whitelist,
}, nil
}
@@ -871,7 +891,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
// The real SSRF guard is newSafeDialContext at connect time.
hostname := parsedURL.Hostname()
if isObviousPrivateHost(hostname) {
if isObviousPrivateHost(hostname, t.whitelist) {
return ErrorResult("fetching private or local network hosts is not allowed")
}
@@ -981,7 +1001,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
func newSafeDialContext(
dialer *net.Dialer,
whitelist *privateHostWhitelist,
) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if allowPrivateWebFetchHosts.Load() {
return dialer.DialContext(ctx, network, address)
@@ -996,7 +1019,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if ip := net.ParseIP(host); ip != nil {
if isPrivateOrRestrictedIP(ip) {
if shouldBlockPrivateIP(ip, whitelist) {
return nil, fmt.Errorf("blocked private or local target: %s", host)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
@@ -1010,7 +1033,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
attempted := 0
var lastErr error
for _, ipAddr := range ipAddrs {
if isPrivateOrRestrictedIP(ipAddr.IP) {
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
continue
}
attempted++
@@ -1022,7 +1045,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if attempted == 0 {
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
}
if lastErr != nil {
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
@@ -1031,10 +1054,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
}
// newPrivateHostWhitelist parses whitelist entries, each of which must be a
// literal IP address or a CIDR block. Entries are whitespace-trimmed and blank
// entries are ignored. It returns (nil, nil) when no usable entry remains, and
// an error naming the first entry that is neither an IP nor a CIDR.
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
	if len(entries) == 0 {
		return nil, nil
	}

	exactIPs := make(map[string]struct{})
	networks := make([]*net.IPNet, 0, len(entries))

	for _, raw := range entries {
		item := strings.TrimSpace(raw)
		if item == "" {
			continue
		}

		if addr := net.ParseIP(item); addr != nil {
			// Store IPs in normalized form so lookups agree on the key.
			exactIPs[normalizeWhitelistIP(addr).String()] = struct{}{}
		} else if _, ipNet, err := net.ParseCIDR(item); err == nil {
			networks = append(networks, ipNet)
		} else {
			return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", item)
		}
	}

	// Only blank entries were supplied: treat that the same as no whitelist.
	if len(exactIPs) == 0 && len(networks) == 0 {
		return nil, nil
	}
	return &privateHostWhitelist{exact: exactIPs, cidrs: networks}, nil
}
// Contains reports whether ip is whitelisted, either as an exactly listed
// address or as a member of one of the whitelisted networks. A nil receiver or
// a nil ip is never contained.
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
	if w == nil || ip == nil {
		return false
	}

	candidate := normalizeWhitelistIP(ip)
	if _, listed := w.exact[candidate.String()]; listed {
		return true
	}

	for i := range w.cidrs {
		if w.cidrs[i].Contains(candidate) {
			return true
		}
	}
	return false
}
// normalizeWhitelistIP returns the 4-byte form of ip when it is an IPv4
// address (including the IPv4-mapped IPv6 form) and returns ip unchanged
// otherwise. A nil ip yields nil.
func normalizeWhitelistIP(ip net.IP) net.IP {
	// To4 yields nil both for a nil input and for a genuine IPv6 address, so
	// in either case the input is handed back as-is.
	if v4 := ip.To4(); v4 != nil {
		return v4
	}
	return ip
}
// shouldBlockPrivateIP reports whether ip must be rejected: it is classified
// as private or restricted and the whitelist does not contain it.
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
	if !isPrivateOrRestrictedIP(ip) {
		return false
	}
	return !whitelist.Contains(ip)
}
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
func isObviousPrivateHost(host string) bool {
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
if allowPrivateWebFetchHosts.Load() {
return false
}
@@ -1050,7 +1135,7 @@ func isObviousPrivateHost(host string) bool {
}
if ip := net.ParseIP(h); ip != nil {
return isPrivateOrRestrictedIP(ip)
return shouldBlockPrivateIP(ip, whitelist)
}
return false
+147
View File
@@ -10,6 +10,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -423,6 +424,29 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
})
}
// serverHostAndPort strips a leading http:// or https:// scheme from rawURL
// and splits the remainder into host and port. It fails the test when the
// remainder is not a valid host:port pair.
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
	t.Helper()

	addr := rawURL
	for _, scheme := range []string{"http://", "https://"} {
		addr = strings.TrimPrefix(addr, scheme)
	}

	h, p, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		t.Fatalf("failed to split host/port from %q: %v", rawURL, splitErr)
	}
	return h, p
}
// singleHostCIDR returns the CIDR block that covers exactly the given IP:
// a /32 for IPv4 addresses and a /128 otherwise. It fails the test when host
// is not a literal IP.
func singleHostCIDR(t *testing.T, host string) string {
	t.Helper()

	parsed := net.ParseIP(host)
	if parsed == nil {
		t.Fatalf("failed to parse IP %q", host)
	}

	suffix := "/128"
	if parsed.To4() != nil {
		suffix = "/32"
	}
	return parsed.String() + suffix
}
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
if err != nil {
@@ -441,6 +465,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
}
}
// TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist checks that fetching
// from the local test server succeeds when the server's host is whitelisted as
// an exact entry.
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
	const payload = "exact whitelist ok"

	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("Content-Type", "text/plain")
		rw.WriteHeader(http.StatusOK)
		rw.Write([]byte(payload))
	}))
	defer srv.Close()

	serverHost, _ := serverHostAndPort(t, srv.URL)
	fetcher, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{serverHost})
	if err != nil {
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	res := fetcher.Execute(context.Background(), map[string]any{"url": srv.URL})
	switch {
	case res.IsError:
		t.Fatalf("expected success for exact whitelisted private IP, got %q", res.ForLLM)
	case !strings.Contains(res.ForLLM, payload):
		t.Fatalf("expected fetched content, got %q", res.ForLLM)
	}
}
// TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist checks that fetching
// from the local test server succeeds when the whitelist holds a single-host
// CIDR block covering the server's address.
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
	const payload = "cidr whitelist ok"

	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("Content-Type", "text/plain")
		rw.WriteHeader(http.StatusOK)
		rw.Write([]byte(payload))
	}))
	defer srv.Close()

	serverHost, _ := serverHostAndPort(t, srv.URL)
	allowed := []string{singleHostCIDR(t, serverHost)}
	fetcher, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, allowed)
	if err != nil {
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	res := fetcher.Execute(context.Background(), map[string]any{"url": srv.URL})
	switch {
	case res.IsError:
		t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", res.ForLLM)
	case !strings.Contains(res.ForLLM, payload):
		t.Fatalf("expected fetched content, got %q", res.ForLLM)
	}
}
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
@@ -570,6 +644,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
}
}
// TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist checks that
// dialing "localhost" through the safe dialer is rejected when no whitelist is
// supplied, even though a loopback listener is available on the target port.
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to listen on loopback: %v", err)
	}
	defer ln.Close()

	_, port, err := net.SplitHostPort(ln.Addr().String())
	if err != nil {
		t.Fatalf("failed to split listener address: %v", err)
	}

	dial := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
	target := net.JoinHostPort("localhost", port)
	if _, err = dial(context.Background(), "tcp", target); err == nil {
		t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
	}

	// Either wording is accepted as evidence that the SSRF guard fired.
	msg := err.Error()
	if !(strings.Contains(msg, "private") || strings.Contains(msg, "whitelisted")) {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution checks that
// dialing "localhost" through the safe dialer succeeds when 127.0.0.0/8 is
// whitelisted, and that the loopback listener actually receives the connection.
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on loopback: %v", err)
}
defer listener.Close()
// Buffered with capacity 1 so the accept goroutine can signal without
// waiting for the test to reach the select below.
accepted := make(chan struct{}, 1)
go func() {
conn, acceptErr := listener.Accept()
if acceptErr != nil {
// Presumably reached when the deferred listener.Close runs before any
// connection arrives; the goroutine just exits.
return
}
conn.Close()
accepted <- struct{}{}
}()
_, port, err := net.SplitHostPort(listener.Addr().String())
if err != nil {
t.Fatalf("failed to split listener address: %v", err)
}
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
if err != nil {
t.Fatalf("failed to parse whitelist: %v", err)
}
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
// Dial by hostname rather than literal IP so the DNS-resolution path of the
// safe dialer is exercised.
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
if err != nil {
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
}
conn.Close()
// Give the accept goroutine up to one second to report the connection.
select {
case <-accepted:
case <-time.After(time.Second):
t.Fatal("expected localhost listener to accept a connection")
}
}
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
tests := []struct {
@@ -660,6 +797,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
}
}
// TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist checks that the
// constructor rejects a whitelist entry that is neither an IP nor a CIDR, and
// that the error mentions the invalid entry.
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
	badEntries := []string{"not-an-ip-or-cidr"}

	if _, err := NewWebFetchToolWithConfig(1024, "", testFetchLimit, badEntries); err == nil {
		t.Fatal("expected invalid whitelist entry to fail")
	} else if !strings.Contains(err.Error(), "invalid entry") {
		t.Fatalf("unexpected error: %v", err)
	}
}
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{