mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into feat/markdown-output-format-web-fetch
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
@@ -425,6 +426,29 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
|
||||
t.Helper()
|
||||
hostPort := strings.TrimPrefix(rawURL, "http://")
|
||||
hostPort = strings.TrimPrefix(hostPort, "https://")
|
||||
host, port, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func singleHostCIDR(t *testing.T, host string) string {
|
||||
t.Helper()
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", host)
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
return ip.String() + "/32"
|
||||
}
|
||||
return ip.String() + "/128"
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
@@ -443,6 +467,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("exact whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{host})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("cidr whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{singleHostCIDR(t, host)})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
@@ -572,6 +646,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
|
||||
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err == nil {
|
||||
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
accepted := make(chan struct{}, 1)
|
||||
go func() {
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
accepted <- struct{}{}
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse whitelist: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
|
||||
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err != nil {
|
||||
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
select {
|
||||
case <-accepted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected localhost listener to accept a connection")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
|
||||
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -662,6 +799,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
||||
_, err := NewWebFetchToolWithConfig(1024, "", testFetchLimit, []string{"not-an-ip-or-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid whitelist entry to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid entry") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
|
||||
Reference in New Issue
Block a user