Merge branch 'upstream-main' into feat/subturn-poc

This commit is contained in:
Administrator
2026-03-18 22:57:01 +08:00
117 changed files with 14857 additions and 7091 deletions
+6 -3
View File
@@ -226,9 +226,12 @@ func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
if !ok {
t.Fatal("expected outbound message")
var msg bus.OutboundMessage
select {
case msg = <-tool.msgBus.OutboundChan():
// got message
case <-ctx.Done():
t.Fatal("timeout waiting for outbound message")
}
if !strings.Contains(msg.Content, "command execution is disabled") {
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
+174 -25
View File
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
@@ -15,6 +16,7 @@ import (
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -776,22 +778,49 @@ type WebFetchTool struct {
maxChars int
proxy string
client *http.Client
format string
fetchLimitBytes int64
whitelist *privateHostWhitelist
}
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
type privateHostWhitelist struct {
exact map[string]struct{}
cidrs []*net.IPNet
}
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
// createHTTPClient cannot fail with an empty proxy string.
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
return NewWebFetchToolWithConfig(maxChars, "", format, fetchLimitBytes, nil)
}
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily.
var allowPrivateWebFetchHosts atomic.Bool
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
func NewWebFetchToolWithProxy(
maxChars int,
proxy string,
format string,
fetchLimitBytes int64,
privateHostWhitelist []string,
) (*WebFetchTool, error) {
return NewWebFetchToolWithConfig(maxChars, proxy, format, fetchLimitBytes, privateHostWhitelist)
}
func NewWebFetchToolWithConfig(
maxChars int,
proxy string,
format string,
fetchLimitBytes int64,
privateHostWhitelist []string,
) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = defaultMaxChars
}
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
if err != nil {
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
}
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
@@ -801,13 +830,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
transport.DialContext = newSafeDialContext(dialer)
transport.DialContext = newSafeDialContext(dialer, whitelist)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
if isObviousPrivateHost(req.URL.Hostname()) {
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
return fmt.Errorf("redirect target is private or local network host")
}
return nil
@@ -819,7 +848,9 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
maxChars: maxChars,
proxy: proxy,
client: client,
format: format,
fetchLimitBytes: fetchLimitBytes,
whitelist: whitelist,
}, nil
}
@@ -871,7 +902,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
// The real SSRF guard is newSafeDialContext at connect time.
hostname := parsedURL.Hostname()
if isObviousPrivateHost(hostname) {
if isObviousPrivateHost(hostname, t.whitelist) {
return ErrorResult("fetching private or local network hosts is not allowed")
}
@@ -906,26 +937,68 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
}
bodyStr := string(body)
contentType := resp.Header.Get("Content-Type")
mediaType, params, err := mime.ParseMediaType(contentType)
if err != nil {
// The most common error here is "mime: no media type" if the header is empty.
logger.WarnCF("tool", "Failed to parse Content-Type", map[string]any{
"raw_header": contentType,
"error": err.Error(),
})
// security fallback
mediaType = "application/octet-stream"
}
charset, hasCharset := params["charset"]
if hasCharset {
// If the charset is not utf-8, we might have to convert the bodyStr
// before passing it to the HTML/Markdown parser
if strings.ToLower(charset) != "utf-8" {
logger.WarnCF("tool", "Note: the content is not in UTF-8", map[string]any{"charset": charset})
}
}
var text, extractor string
if strings.Contains(contentType, "application/json") {
switch {
case mediaType == "application/json":
var jsonData any
if err := json.Unmarshal(body, &jsonData); err == nil {
formatted, _ := json.MarshalIndent(jsonData, "", " ")
text = string(formatted)
extractor = "json"
} else {
text = string(body)
if err := json.Unmarshal(body, &jsonData); err != nil {
text = bodyStr
extractor = "raw"
break
}
} else if strings.Contains(contentType, "text/html") || len(body) > 0 &&
(strings.HasPrefix(string(body), "<!DOCTYPE") || strings.HasPrefix(strings.ToLower(string(body)), "<html")) {
text = t.extractText(string(body))
extractor = "text"
} else {
text = string(body)
formatted, err := json.MarshalIndent(jsonData, "", " ")
if err != nil {
text = bodyStr
extractor = "raw"
break
}
text = string(formatted)
extractor = "json"
case mediaType == "text/html" || looksLikeHTML(bodyStr):
switch strings.ToLower(t.format) {
case "markdown":
var err error
text, err = utils.HtmlToMarkdown(bodyStr)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to HTML to markdown: %v", err))
}
extractor = "markdown"
default:
text = t.extractText(bodyStr)
extractor = "text"
}
default:
text = bodyStr
extractor = "raw"
}
@@ -957,6 +1030,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
}
}
func looksLikeHTML(body string) bool {
if body == "" {
return false
}
lower := strings.ToLower(body)
return strings.HasPrefix(body, "<!doctype") ||
strings.HasPrefix(lower, "<html")
}
func (t *WebFetchTool) extractText(htmlContent string) string {
result := reScript.ReplaceAllLiteralString(htmlContent, "")
result = reStyle.ReplaceAllLiteralString(result, "")
@@ -981,7 +1065,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
func newSafeDialContext(
dialer *net.Dialer,
whitelist *privateHostWhitelist,
) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if allowPrivateWebFetchHosts.Load() {
return dialer.DialContext(ctx, network, address)
@@ -996,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if ip := net.ParseIP(host); ip != nil {
if isPrivateOrRestrictedIP(ip) {
if shouldBlockPrivateIP(ip, whitelist) {
return nil, fmt.Errorf("blocked private or local target: %s", host)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
@@ -1010,7 +1097,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
attempted := 0
var lastErr error
for _, ipAddr := range ipAddrs {
if isPrivateOrRestrictedIP(ipAddr.IP) {
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
continue
}
attempted++
@@ -1022,7 +1109,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if attempted == 0 {
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
}
if lastErr != nil {
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
@@ -1031,10 +1118,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
}
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
if len(entries) == 0 {
return nil, nil
}
whitelist := &privateHostWhitelist{
exact: make(map[string]struct{}),
cidrs: make([]*net.IPNet, 0, len(entries)),
}
for _, entry := range entries {
entry = strings.TrimSpace(entry)
if entry == "" {
continue
}
if ip := net.ParseIP(entry); ip != nil {
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
continue
}
_, network, err := net.ParseCIDR(entry)
if err != nil {
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
}
whitelist.cidrs = append(whitelist.cidrs, network)
}
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
return nil, nil
}
return whitelist, nil
}
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
if w == nil || ip == nil {
return false
}
normalized := normalizeWhitelistIP(ip)
if _, ok := w.exact[normalized.String()]; ok {
return true
}
for _, network := range w.cidrs {
if network.Contains(normalized) {
return true
}
}
return false
}
func normalizeWhitelistIP(ip net.IP) net.IP {
if ip == nil {
return nil
}
if ip4 := ip.To4(); ip4 != nil {
return ip4
}
return ip
}
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
}
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
func isObviousPrivateHost(host string) bool {
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
if allowPrivateWebFetchHosts.Load() {
return false
}
@@ -1050,7 +1199,7 @@ func isObviousPrivateHost(host string) bool {
}
if ip := net.ParseIP(h); ip != nil {
return isPrivateOrRestrictedIP(ip)
return shouldBlockPrivateIP(ip, whitelist)
}
return false
+170 -20
View File
@@ -10,11 +10,15 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
const testFetchLimit = int64(10 * 1024 * 1024)
const (
testFetchLimit = int64(10 * 1024 * 1024)
format = "plaintext"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
@@ -27,7 +31,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -69,7 +73,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -94,7 +98,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -119,7 +123,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -144,7 +148,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -178,7 +182,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
tool, err := NewWebFetchTool(1000, format, testFetchLimit) // Limit to 1000 chars
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -228,7 +232,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
defer ts.Close()
// Initialize the tool
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -311,7 +315,7 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -423,8 +427,31 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
})
}
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
t.Helper()
hostPort := strings.TrimPrefix(rawURL, "http://")
hostPort = strings.TrimPrefix(hostPort, "https://")
host, port, err := net.SplitHostPort(hostPort)
if err != nil {
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
}
return host, port
}
func singleHostCIDR(t *testing.T, host string) string {
t.Helper()
ip := net.ParseIP(host)
if ip == nil {
t.Fatalf("failed to parse IP %q", host)
}
if ip.To4() != nil {
return ip.String() + "/32"
}
return ip.String() + "/128"
}
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -441,6 +468,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
}
}
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("exact whitelist ok"))
}))
defer server.Close()
host, _ := serverHostAndPort(t, server.URL)
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{host})
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
t.Fatalf("expected fetched content, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("cidr whitelist ok"))
}))
defer server.Close()
host, _ := serverHostAndPort(t, server.URL)
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{singleHostCIDR(t, host)})
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
t.Fatalf("expected fetched content, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
@@ -451,7 +528,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -466,7 +543,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked
func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -481,7 +558,7 @@ func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked
func TestWebFetch_BlocksMetadataIP(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -496,7 +573,7 @@ func TestWebFetch_BlocksMetadataIP(t *testing.T) {
// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked
func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -511,7 +588,7 @@ func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked
func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -527,7 +604,7 @@ func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked
func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -557,7 +634,7 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
allowPrivateWebFetchHosts.Store(false)
defer allowPrivateWebFetchHosts.Store(true)
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -570,6 +647,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
}
}
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on loopback: %v", err)
}
defer listener.Close()
_, port, err := net.SplitHostPort(listener.Addr().String())
if err != nil {
t.Fatalf("failed to split listener address: %v", err)
}
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
if err == nil {
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
}
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on loopback: %v", err)
}
defer listener.Close()
accepted := make(chan struct{}, 1)
go func() {
conn, acceptErr := listener.Accept()
if acceptErr != nil {
return
}
conn.Close()
accepted <- struct{}{}
}()
_, port, err := net.SplitHostPort(listener.Addr().String())
if err != nil {
t.Fatalf("failed to split listener address: %v", err)
}
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
if err != nil {
t.Fatalf("failed to parse whitelist: %v", err)
}
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
if err != nil {
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
}
conn.Close()
select {
case <-accepted:
case <-time.After(time.Second):
t.Fatal("expected localhost listener to accept a connection")
}
}
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
tests := []struct {
@@ -615,7 +755,7 @@ func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -639,7 +779,7 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
}
func TestNewWebFetchToolWithProxy(t *testing.T) {
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else if tool.maxChars != 1024 {
@@ -650,7 +790,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
}
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -660,6 +800,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
}
}
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
_, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, []string{"not-an-ip-or-cidr"})
if err == nil {
t.Fatal("expected invalid whitelist entry to fail")
}
if !strings.Contains(err.Error(), "invalid entry") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{