mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(security): harden unauthenticated tool-exec paths (#1360)
* fix(security): harden unauthenticated tool-exec paths (GHSA-pv8c-p6jf-3fpp) - Exec tool: channel-based access control (default deny remote) - Cron tool: command scheduling restricted to internal channels - Web fetch: SSRF defense-in-depth (pre-flight + dial-time + redirect checks) - File permissions: session/state dirs 0700, files 0600 - Registry: inject __channel/__chat_id into tool args (replaces racy SetContext) 28 new security regression tests. (cherry picked from commit 191446ae19021604d3d5b0d9376b9655ab749105) * fix(exec): revalidate working_dir before command start * test(web): allow local oversized payload fixture --------- Co-authored-by: xj <gh-xj@users.noreply.github.com>
This commit is contained in:
+146
-1
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -818,6 +819,10 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error)
|
||||
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
|
||||
}
|
||||
|
||||
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
|
||||
// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily.
|
||||
var allowPrivateWebFetchHosts atomic.Bool
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = defaultMaxChars
|
||||
@@ -826,10 +831,20 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
}
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = newSafeDialContext(dialer)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
if isObviousPrivateHost(req.URL.Hostname()) {
|
||||
return fmt.Errorf("redirect target is private or local network host")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fetchLimitBytes <= 0 {
|
||||
@@ -888,6 +903,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("missing domain in URL")
|
||||
}
|
||||
|
||||
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
|
||||
// The real SSRF guard is newSafeDialContext at connect time.
|
||||
hostname := parsedURL.Hostname()
|
||||
if isObviousPrivateHost(hostname) {
|
||||
return ErrorResult("fetching private or local network hosts is not allowed")
|
||||
}
|
||||
|
||||
maxChars := t.maxChars
|
||||
if mc, ok := args["maxChars"].(float64); ok {
|
||||
if int(mc) > 100 {
|
||||
@@ -901,7 +923,6 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
@@ -992,3 +1013,127 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
return strings.Join(cleanLines, "\n")
|
||||
}
|
||||
|
||||
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
|
||||
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
|
||||
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid target address %q: %w", address, err)
|
||||
}
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("empty target host")
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateOrRestrictedIP(ip) {
|
||||
return nil, fmt.Errorf("blocked private or local target: %s", host)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
}
|
||||
|
||||
ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve %s: %w", host, err)
|
||||
}
|
||||
|
||||
attempted := 0
|
||||
var lastErr error
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if isPrivateOrRestrictedIP(ipAddr.IP) {
|
||||
continue
|
||||
}
|
||||
attempted++
|
||||
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if attempted == 0 {
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
|
||||
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
|
||||
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
|
||||
func isObviousPrivateHost(host string) bool {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
h := strings.ToLower(strings.TrimSpace(host))
|
||||
h = strings.TrimSuffix(h, ".")
|
||||
if h == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
if h == "localhost" || strings.HasSuffix(h, ".localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
return isPrivateOrRestrictedIP(ip)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isPrivateOrRestrictedIP returns true for IPs that should never be reached via web_fetch:
|
||||
// RFC 1918, loopback, link-local (incl. cloud metadata 169.254.x.x), carrier-grade NAT,
|
||||
// IPv6 unique-local (fc00::/7), 6to4 (2002::/16), and Teredo (2001:0000::/32).
|
||||
func isPrivateOrRestrictedIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
|
||||
ip.IsMulticast() || ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4 private, loopback, link-local, and carrier-grade NAT ranges.
|
||||
if ip4[0] == 10 ||
|
||||
ip4[0] == 127 ||
|
||||
ip4[0] == 0 ||
|
||||
(ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) ||
|
||||
(ip4[0] == 192 && ip4[1] == 168) ||
|
||||
(ip4[0] == 169 && ip4[1] == 254) ||
|
||||
(ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if len(ip) == net.IPv6len {
|
||||
// IPv6 unique local addresses (fc00::/7)
|
||||
if (ip[0] & 0xfe) == 0xfc {
|
||||
return true
|
||||
}
|
||||
// 6to4 addresses (2002::/16): check the embedded IPv4 at bytes [2:6].
|
||||
if ip[0] == 0x20 && ip[1] == 0x02 {
|
||||
embedded := net.IPv4(ip[2], ip[3], ip[4], ip[5])
|
||||
return isPrivateOrRestrictedIP(embedded)
|
||||
}
|
||||
// Teredo (2001:0000::/32): client IPv4 is at bytes [12:16], XOR-inverted.
|
||||
if ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x00 && ip[3] == 0x00 {
|
||||
client := net.IPv4(ip[12]^0xff, ip[13]^0xff, ip[14]^0xff, ip[15]^0xff)
|
||||
return isPrivateOrRestrictedIP(client)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user