fix(security): harden unauthenticated tool-exec paths (#1360)

* fix(security): harden unauthenticated tool-exec paths (GHSA-pv8c-p6jf-3fpp)

- Exec tool: channel-based access control (default deny remote)
- Cron tool: command scheduling restricted to internal channels
- Web fetch: SSRF defense-in-depth (pre-flight + dial-time + redirect checks)
- File permissions: session/state dirs 0700, files 0600
- Registry: inject __channel/__chat_id into tool args (replaces racy SetContext)

28 new security regression tests.

(cherry picked from commit 191446ae19021604d3d5b0d9376b9655ab749105)

* fix(exec): revalidate working_dir before command start

* test(web): allow local oversized payload fixture

---------

Co-authored-by: xj <gh-xj@users.noreply.github.com>
This commit is contained in:
wenjie
2026-03-11 19:22:20 +08:00
committed by GitHub
parent dea06c391c
commit 8c2a9332c6
14 changed files with 622 additions and 30 deletions
+146 -1
View File
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"regexp"
@@ -818,6 +819,10 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error)
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
}
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily.
var allowPrivateWebFetchHosts atomic.Bool
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = defaultMaxChars
@@ -826,10 +831,20 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
}
if transport, ok := client.Transport.(*http.Transport); ok {
dialer := &net.Dialer{
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
transport.DialContext = newSafeDialContext(dialer)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
if isObviousPrivateHost(req.URL.Hostname()) {
return fmt.Errorf("redirect target is private or local network host")
}
return nil
}
if fetchLimitBytes <= 0 {
@@ -888,6 +903,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult("missing domain in URL")
}
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
// The real SSRF guard is newSafeDialContext at connect time.
hostname := parsedURL.Hostname()
if isObviousPrivateHost(hostname) {
return ErrorResult("fetching private or local network hosts is not allowed")
}
maxChars := t.maxChars
if mc, ok := args["maxChars"].(float64); ok {
if int(mc) > 100 {
@@ -901,7 +923,6 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
}
req.Header.Set("User-Agent", userAgent)
resp, err := t.client.Do(req)
if err != nil {
return ErrorResult(fmt.Sprintf("request failed: %v", err))
@@ -992,3 +1013,127 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
return strings.Join(cleanLines, "\n")
}
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if allowPrivateWebFetchHosts.Load() {
return dialer.DialContext(ctx, network, address)
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("invalid target address %q: %w", address, err)
}
if host == "" {
return nil, fmt.Errorf("empty target host")
}
if ip := net.ParseIP(host); ip != nil {
if isPrivateOrRestrictedIP(ip) {
return nil, fmt.Errorf("blocked private or local target: %s", host)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
}
ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to resolve %s: %w", host, err)
}
attempted := 0
var lastErr error
for _, ipAddr := range ipAddrs {
if isPrivateOrRestrictedIP(ipAddr.IP) {
continue
}
attempted++
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
if err == nil {
return conn, nil
}
lastErr = err
}
if attempted == 0 {
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
}
if lastErr != nil {
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
}
return nil, fmt.Errorf("failed connecting to public addresses for %s", host)
}
}
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
func isObviousPrivateHost(host string) bool {
if allowPrivateWebFetchHosts.Load() {
return false
}
h := strings.ToLower(strings.TrimSpace(host))
h = strings.TrimSuffix(h, ".")
if h == "" {
return true
}
if h == "localhost" || strings.HasSuffix(h, ".localhost") {
return true
}
if ip := net.ParseIP(h); ip != nil {
return isPrivateOrRestrictedIP(ip)
}
return false
}
// isPrivateOrRestrictedIP returns true for IPs that should never be reached via web_fetch:
// RFC 1918, loopback, link-local (incl. cloud metadata 169.254.x.x), carrier-grade NAT,
// IPv6 unique-local (fc00::/7), 6to4 (2002::/16), and Teredo (2001:0000::/32).
func isPrivateOrRestrictedIP(ip net.IP) bool {
if ip == nil {
return true
}
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
ip.IsMulticast() || ip.IsUnspecified() {
return true
}
if ip4 := ip.To4(); ip4 != nil {
// IPv4 private, loopback, link-local, and carrier-grade NAT ranges.
if ip4[0] == 10 ||
ip4[0] == 127 ||
ip4[0] == 0 ||
(ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) ||
(ip4[0] == 192 && ip4[1] == 168) ||
(ip4[0] == 169 && ip4[1] == 254) ||
(ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127) {
return true
}
return false
}
if len(ip) == net.IPv6len {
// IPv6 unique local addresses (fc00::/7)
if (ip[0] & 0xfe) == 0xfc {
return true
}
// 6to4 addresses (2002::/16): check the embedded IPv4 at bytes [2:6].
if ip[0] == 0x20 && ip[1] == 0x02 {
embedded := net.IPv4(ip[2], ip[3], ip[4], ip[5])
return isPrivateOrRestrictedIP(embedded)
}
// Teredo (2001:0000::/32): client IPv4 is at bytes [12:16], XOR-inverted.
if ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x00 && ip[3] == 0x00 {
client := net.IPv4(ip[12]^0xff, ip[13]^0xff, ip[14]^0xff, ip[15]^0xff)
return isPrivateOrRestrictedIP(client)
}
}
return false
}