mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into feat/markdown-output-format-web-fetch
This commit is contained in:
+95
-10
@@ -780,11 +780,17 @@ type WebFetchTool struct {
|
||||
client *http.Client
|
||||
format string
|
||||
fetchLimitBytes int64
|
||||
whitelist *privateHostWhitelist
|
||||
}
|
||||
|
||||
type privateHostWhitelist struct {
|
||||
exact map[string]struct{}
|
||||
cidrs []*net.IPNet
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
// createHTTPClient cannot fail with an empty proxy string.
|
||||
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes)
|
||||
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
|
||||
@@ -792,9 +798,22 @@ func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFe
|
||||
var allowPrivateWebFetchHosts atomic.Bool
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithConfig(
|
||||
maxChars int,
|
||||
proxy string,
|
||||
fetchLimitBytes int64,
|
||||
privateHostWhitelist []string,
|
||||
) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
|
||||
}
|
||||
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
@@ -804,13 +823,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = newSafeDialContext(dialer)
|
||||
transport.DialContext = newSafeDialContext(dialer, whitelist)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
if isObviousPrivateHost(req.URL.Hostname()) {
|
||||
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
|
||||
return fmt.Errorf("redirect target is private or local network host")
|
||||
}
|
||||
return nil
|
||||
@@ -824,6 +843,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
|
||||
client: client,
|
||||
format: format,
|
||||
fetchLimitBytes: fetchLimitBytes,
|
||||
whitelist: whitelist,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -875,7 +895,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
|
||||
// The real SSRF guard is newSafeDialContext at connect time.
|
||||
hostname := parsedURL.Hostname()
|
||||
if isObviousPrivateHost(hostname) {
|
||||
if isObviousPrivateHost(hostname, t.whitelist) {
|
||||
return ErrorResult("fetching private or local network hosts is not allowed")
|
||||
}
|
||||
|
||||
@@ -1019,7 +1039,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
|
||||
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
|
||||
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
|
||||
func newSafeDialContext(
|
||||
dialer *net.Dialer,
|
||||
whitelist *privateHostWhitelist,
|
||||
) func(context.Context, string, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
@@ -1034,7 +1057,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateOrRestrictedIP(ip) {
|
||||
if shouldBlockPrivateIP(ip, whitelist) {
|
||||
return nil, fmt.Errorf("blocked private or local target: %s", host)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
@@ -1048,7 +1071,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
attempted := 0
|
||||
var lastErr error
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if isPrivateOrRestrictedIP(ipAddr.IP) {
|
||||
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
|
||||
continue
|
||||
}
|
||||
attempted++
|
||||
@@ -1060,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if attempted == 0 {
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
|
||||
@@ -1069,10 +1092,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
}
|
||||
|
||||
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
|
||||
if len(entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
whitelist := &privateHostWhitelist{
|
||||
exact: make(map[string]struct{}),
|
||||
cidrs: make([]*net.IPNet, 0, len(entries)),
|
||||
}
|
||||
for _, entry := range entries {
|
||||
entry = strings.TrimSpace(entry)
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
|
||||
continue
|
||||
}
|
||||
_, network, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
|
||||
}
|
||||
whitelist.cidrs = append(whitelist.cidrs, network)
|
||||
}
|
||||
|
||||
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return whitelist, nil
|
||||
}
|
||||
|
||||
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
|
||||
if w == nil || ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
normalized := normalizeWhitelistIP(ip)
|
||||
if _, ok := w.exact[normalized.String()]; ok {
|
||||
return true
|
||||
}
|
||||
for _, network := range w.cidrs {
|
||||
if network.Contains(normalized) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeWhitelistIP(ip net.IP) net.IP {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
|
||||
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
|
||||
}
|
||||
|
||||
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
|
||||
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
|
||||
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
|
||||
func isObviousPrivateHost(host string) bool {
|
||||
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return false
|
||||
}
|
||||
@@ -1088,7 +1173,7 @@ func isObviousPrivateHost(host string) bool {
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
return isPrivateOrRestrictedIP(ip)
|
||||
return shouldBlockPrivateIP(ip, whitelist)
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
Reference in New Issue
Block a user