Merge branch 'main' into feat/markdown-output-format-web-fetch

This commit is contained in:
Mauro
2026-03-17 16:37:22 +01:00
committed by GitHub
104 changed files with 6151 additions and 1202 deletions
+95 -10
View File
@@ -780,11 +780,17 @@ type WebFetchTool struct {
client *http.Client
format string
fetchLimitBytes int64
whitelist *privateHostWhitelist
}
type privateHostWhitelist struct {
exact map[string]struct{}
cidrs []*net.IPNet
}
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
// createHTTPClient cannot fail with an empty proxy string.
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes)
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes, nil)
}
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
@@ -792,9 +798,22 @@ func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFe
var allowPrivateWebFetchHosts atomic.Bool
func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil)
}
func NewWebFetchToolWithConfig(
maxChars int,
proxy string,
fetchLimitBytes int64,
privateHostWhitelist []string,
) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = defaultMaxChars
}
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
if err != nil {
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
}
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
@@ -804,13 +823,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
transport.DialContext = newSafeDialContext(dialer)
transport.DialContext = newSafeDialContext(dialer, whitelist)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
if isObviousPrivateHost(req.URL.Hostname()) {
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
return fmt.Errorf("redirect target is private or local network host")
}
return nil
@@ -824,6 +843,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
client: client,
format: format,
fetchLimitBytes: fetchLimitBytes,
whitelist: whitelist,
}, nil
}
@@ -875,7 +895,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
// The real SSRF guard is newSafeDialContext at connect time.
hostname := parsedURL.Hostname()
if isObviousPrivateHost(hostname) {
if isObviousPrivateHost(hostname, t.whitelist) {
return ErrorResult("fetching private or local network hosts is not allowed")
}
@@ -1019,7 +1039,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
func newSafeDialContext(
dialer *net.Dialer,
whitelist *privateHostWhitelist,
) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if allowPrivateWebFetchHosts.Load() {
return dialer.DialContext(ctx, network, address)
@@ -1034,7 +1057,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if ip := net.ParseIP(host); ip != nil {
if isPrivateOrRestrictedIP(ip) {
if shouldBlockPrivateIP(ip, whitelist) {
return nil, fmt.Errorf("blocked private or local target: %s", host)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
@@ -1048,7 +1071,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
attempted := 0
var lastErr error
for _, ipAddr := range ipAddrs {
if isPrivateOrRestrictedIP(ipAddr.IP) {
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
continue
}
attempted++
@@ -1060,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if attempted == 0 {
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
}
if lastErr != nil {
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
@@ -1069,10 +1092,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
}
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
if len(entries) == 0 {
return nil, nil
}
whitelist := &privateHostWhitelist{
exact: make(map[string]struct{}),
cidrs: make([]*net.IPNet, 0, len(entries)),
}
for _, entry := range entries {
entry = strings.TrimSpace(entry)
if entry == "" {
continue
}
if ip := net.ParseIP(entry); ip != nil {
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
continue
}
_, network, err := net.ParseCIDR(entry)
if err != nil {
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
}
whitelist.cidrs = append(whitelist.cidrs, network)
}
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
return nil, nil
}
return whitelist, nil
}
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
if w == nil || ip == nil {
return false
}
normalized := normalizeWhitelistIP(ip)
if _, ok := w.exact[normalized.String()]; ok {
return true
}
for _, network := range w.cidrs {
if network.Contains(normalized) {
return true
}
}
return false
}
func normalizeWhitelistIP(ip net.IP) net.IP {
if ip == nil {
return nil
}
if ip4 := ip.To4(); ip4 != nil {
return ip4
}
return ip
}
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
}
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
func isObviousPrivateHost(host string) bool {
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
if allowPrivateWebFetchHosts.Load() {
return false
}
@@ -1088,7 +1173,7 @@ func isObviousPrivateHost(host string) bool {
}
if ip := net.ParseIP(h); ip != nil {
return isPrivateOrRestrictedIP(ip)
return shouldBlockPrivateIP(ip, whitelist)
}
return false