mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(web): whitelist private fetch targets (#1688)
* feat(web): whitelist private fetch targets * test(web): avoid accept error shadowing --------- Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com>
This commit is contained in:
@@ -351,7 +351,8 @@
|
||||
"search_engine": "search_std",
|
||||
"max_results": 5
|
||||
},
|
||||
"fetch_limit_bytes": 10485760
|
||||
"fetch_limit_bytes": 10485760,
|
||||
"private_host_whitelist": []
|
||||
},
|
||||
"cron": {
|
||||
"enabled": true,
|
||||
|
||||
+6
-1
@@ -159,7 +159,12 @@ func registerSharedTools(
|
||||
}
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("web_fetch") {
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
fetchTool, err := tools.NewWebFetchToolWithConfig(
|
||||
50000,
|
||||
cfg.Tools.Web.Proxy,
|
||||
cfg.Tools.Web.FetchLimitBytes,
|
||||
cfg.Tools.Web.PrivateHostWhitelist,
|
||||
)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
|
||||
@@ -695,8 +695,9 @@ type WebToolsConfig struct {
|
||||
GLMSearch GLMSearchConfig ` json:"glm_search"`
|
||||
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
|
||||
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
|
||||
PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
|
||||
+95
-10
@@ -777,11 +777,17 @@ type WebFetchTool struct {
|
||||
proxy string
|
||||
client *http.Client
|
||||
fetchLimitBytes int64
|
||||
whitelist *privateHostWhitelist
|
||||
}
|
||||
|
||||
type privateHostWhitelist struct {
|
||||
exact map[string]struct{}
|
||||
cidrs []*net.IPNet
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
// createHTTPClient cannot fail with an empty proxy string.
|
||||
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
|
||||
return NewWebFetchToolWithConfig(maxChars, "", fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
|
||||
@@ -789,9 +795,22 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error)
|
||||
var allowPrivateWebFetchHosts atomic.Bool
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithConfig(
|
||||
maxChars int,
|
||||
proxy string,
|
||||
fetchLimitBytes int64,
|
||||
privateHostWhitelist []string,
|
||||
) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
|
||||
}
|
||||
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
@@ -801,13 +820,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = newSafeDialContext(dialer)
|
||||
transport.DialContext = newSafeDialContext(dialer, whitelist)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
if isObviousPrivateHost(req.URL.Hostname()) {
|
||||
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
|
||||
return fmt.Errorf("redirect target is private or local network host")
|
||||
}
|
||||
return nil
|
||||
@@ -820,6 +839,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
|
||||
proxy: proxy,
|
||||
client: client,
|
||||
fetchLimitBytes: fetchLimitBytes,
|
||||
whitelist: whitelist,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -871,7 +891,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
|
||||
// The real SSRF guard is newSafeDialContext at connect time.
|
||||
hostname := parsedURL.Hostname()
|
||||
if isObviousPrivateHost(hostname) {
|
||||
if isObviousPrivateHost(hostname, t.whitelist) {
|
||||
return ErrorResult("fetching private or local network hosts is not allowed")
|
||||
}
|
||||
|
||||
@@ -981,7 +1001,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
|
||||
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
|
||||
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
|
||||
func newSafeDialContext(
|
||||
dialer *net.Dialer,
|
||||
whitelist *privateHostWhitelist,
|
||||
) func(context.Context, string, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
@@ -996,7 +1019,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateOrRestrictedIP(ip) {
|
||||
if shouldBlockPrivateIP(ip, whitelist) {
|
||||
return nil, fmt.Errorf("blocked private or local target: %s", host)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
@@ -1010,7 +1033,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
attempted := 0
|
||||
var lastErr error
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if isPrivateOrRestrictedIP(ipAddr.IP) {
|
||||
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
|
||||
continue
|
||||
}
|
||||
attempted++
|
||||
@@ -1022,7 +1045,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if attempted == 0 {
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
|
||||
@@ -1031,10 +1054,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
}
|
||||
|
||||
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
|
||||
if len(entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
whitelist := &privateHostWhitelist{
|
||||
exact: make(map[string]struct{}),
|
||||
cidrs: make([]*net.IPNet, 0, len(entries)),
|
||||
}
|
||||
for _, entry := range entries {
|
||||
entry = strings.TrimSpace(entry)
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
|
||||
continue
|
||||
}
|
||||
_, network, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
|
||||
}
|
||||
whitelist.cidrs = append(whitelist.cidrs, network)
|
||||
}
|
||||
|
||||
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return whitelist, nil
|
||||
}
|
||||
|
||||
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
|
||||
if w == nil || ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
normalized := normalizeWhitelistIP(ip)
|
||||
if _, ok := w.exact[normalized.String()]; ok {
|
||||
return true
|
||||
}
|
||||
for _, network := range w.cidrs {
|
||||
if network.Contains(normalized) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeWhitelistIP(ip net.IP) net.IP {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
|
||||
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
|
||||
}
|
||||
|
||||
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
|
||||
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
|
||||
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
|
||||
func isObviousPrivateHost(host string) bool {
|
||||
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return false
|
||||
}
|
||||
@@ -1050,7 +1135,7 @@ func isObviousPrivateHost(host string) bool {
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
return isPrivateOrRestrictedIP(ip)
|
||||
return shouldBlockPrivateIP(ip, whitelist)
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
@@ -423,6 +424,29 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
|
||||
t.Helper()
|
||||
hostPort := strings.TrimPrefix(rawURL, "http://")
|
||||
hostPort = strings.TrimPrefix(hostPort, "https://")
|
||||
host, port, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func singleHostCIDR(t *testing.T, host string) string {
|
||||
t.Helper()
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", host)
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
return ip.String() + "/32"
|
||||
}
|
||||
return ip.String() + "/128"
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
if err != nil {
|
||||
@@ -441,6 +465,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("exact whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{host})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("cidr whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{singleHostCIDR(t, host)})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
@@ -570,6 +644,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
|
||||
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err == nil {
|
||||
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
accepted := make(chan struct{}, 1)
|
||||
go func() {
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
accepted <- struct{}{}
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse whitelist: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
|
||||
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err != nil {
|
||||
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
select {
|
||||
case <-accepted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected localhost listener to accept a connection")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
|
||||
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -660,6 +797,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
||||
_, err := NewWebFetchToolWithConfig(1024, "", testFetchLimit, []string{"not-an-ip-or-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid whitelist entry to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid entry") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
|
||||
Reference in New Issue
Block a user