mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into fix/max-payload-size-in-web-fetch
This commit is contained in:
+12
-6
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -33,15 +34,19 @@ type CronTool struct {
|
||||
func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) *CronTool {
|
||||
execTool := NewExecToolWithConfig(workspace, restrict, config)
|
||||
) (*CronTool, error) {
|
||||
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
}
|
||||
|
||||
execTool.SetTimeout(execTimeout)
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the tool name
|
||||
@@ -218,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
return SilentResult("No scheduled jobs")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
var result strings.Builder
|
||||
result.WriteString("Scheduled jobs:\n")
|
||||
for _, j := range jobs {
|
||||
var scheduleInfo string
|
||||
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||
@@ -230,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
} else {
|
||||
scheduleInfo = "unknown"
|
||||
}
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
|
||||
}
|
||||
|
||||
return SilentResult(result)
|
||||
return SilentResult(result.String())
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
|
||||
|
||||
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
for i := range 50 {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
|
||||
+57
-51
@@ -24,56 +24,64 @@ type ExecTool struct {
|
||||
restrictToWorkspace bool
|
||||
}
|
||||
|
||||
var defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
var (
|
||||
defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
// Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(
|
||||
`\b(format|mkfs|diskpart)\b\s`,
|
||||
),
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) *ExecTool {
|
||||
// absolutePathPattern matches absolute file paths in commands (Unix and Windows).
|
||||
absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
|
||||
if config != nil {
|
||||
@@ -86,8 +94,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
for _, pattern := range execConfig.CustomDenyPatterns {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err)
|
||||
continue
|
||||
return nil, fmt.Errorf("invalid custom deny pattern %q: %w", pattern, err)
|
||||
}
|
||||
denyPatterns = append(denyPatterns, re)
|
||||
}
|
||||
@@ -106,7 +113,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
restrictToWorkspace: restrict,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *ExecTool) Name() string {
|
||||
@@ -288,8 +295,7 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
pathPattern := regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
|
||||
matches := pathPattern.FindAllString(cmd, -1)
|
||||
matches := absolutePathPattern.FindAllString(cmd, -1)
|
||||
|
||||
for _, raw := range matches {
|
||||
p, err := filepath.Abs(raw)
|
||||
|
||||
+48
-11
@@ -11,7 +11,10 @@ import (
|
||||
|
||||
// TestShellTool_Success verifies successful command execution
|
||||
func TestShellTool_Success(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -38,7 +41,10 @@ func TestShellTool_Success(t *testing.T) {
|
||||
|
||||
// TestShellTool_Failure verifies failed command execution
|
||||
func TestShellTool_Failure(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -65,7 +71,11 @@ func TestShellTool_Failure(t *testing.T) {
|
||||
|
||||
// TestShellTool_Timeout verifies command timeout handling
|
||||
func TestShellTool_Timeout(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
tool.SetTimeout(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -93,7 +103,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0o644)
|
||||
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -114,7 +127,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
|
||||
|
||||
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
|
||||
func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -135,7 +151,10 @@ func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
|
||||
// TestShellTool_MissingCommand verifies error handling for missing command
|
||||
func TestShellTool_MissingCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
@@ -150,7 +169,10 @@ func TestShellTool_MissingCommand(t *testing.T) {
|
||||
|
||||
// TestShellTool_StderrCapture verifies stderr is captured and included
|
||||
func TestShellTool_StderrCapture(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -170,7 +192,10 @@ func TestShellTool_StderrCapture(t *testing.T) {
|
||||
|
||||
// TestShellTool_OutputTruncation verifies long output is truncated
|
||||
func TestShellTool_OutputTruncation(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
// Generate long output (>10000 chars)
|
||||
@@ -198,7 +223,11 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
|
||||
t.Fatalf("failed to create outside dir: %v", err)
|
||||
}
|
||||
|
||||
tool := NewExecTool(workspace, true)
|
||||
tool, err := NewExecTool(workspace, true)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "pwd",
|
||||
"working_dir": outsideDir,
|
||||
@@ -232,7 +261,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
|
||||
t.Skipf("symlinks not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
tool := NewExecTool(workspace, true)
|
||||
tool, err := NewExecTool(workspace, true)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "cat secret.txt",
|
||||
"working_dir": link,
|
||||
@@ -249,7 +282,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
|
||||
// TestShellTool_RestrictToWorkspace verifies workspace restriction
|
||||
func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool := NewExecTool(tmpDir, false)
|
||||
tool, err := NewExecTool(tmpDir, false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
tool.SetRestrictToWorkspace(true)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -22,7 +22,11 @@ func processExists(pid int) bool {
|
||||
}
|
||||
|
||||
func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
|
||||
tool := NewExecTool(t.TempDir(), false)
|
||||
tool, err := NewExecTool(t.TempDir(), false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
tool.SetTimeout(500 * time.Millisecond)
|
||||
|
||||
args := map[string]any{
|
||||
|
||||
+58
-48
@@ -16,6 +16,14 @@ import (
|
||||
|
||||
const (
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
|
||||
// HTTP client timeouts for web tool providers.
|
||||
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
|
||||
perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
|
||||
fetchTimeout = 60 * time.Second // WebFetchTool
|
||||
|
||||
defaultMaxChars = 50000
|
||||
maxRedirects = 5
|
||||
)
|
||||
|
||||
// Pre-compiled regexes for HTML text extraction
|
||||
@@ -75,6 +83,7 @@ type SearchProvider interface {
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -89,11 +98,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -144,6 +149,7 @@ type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -175,11 +181,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -227,7 +229,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct {
|
||||
proxy string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -240,11 +243,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -286,7 +285,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
|
||||
maxItems := min(len(matches), count)
|
||||
|
||||
for i := 0; i < maxItems; i++ {
|
||||
for i := range maxItems {
|
||||
urlStr := matches[i][1]
|
||||
title := stripTags(matches[i][2])
|
||||
title = strings.TrimSpace(title)
|
||||
@@ -294,9 +293,9 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
// URL decoding if needed
|
||||
if strings.Contains(urlStr, "uddg=") {
|
||||
if u, err := url.QueryUnescape(urlStr); err == nil {
|
||||
idx := strings.Index(u, "uddg=")
|
||||
if idx != -1 {
|
||||
urlStr = u[idx+5:]
|
||||
_, after, ok := strings.Cut(u, "uddg=")
|
||||
if ok {
|
||||
urlStr = after
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -323,6 +322,7 @@ func stripTags(content string) string {
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -357,11 +357,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 30*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -416,43 +412,60 @@ type WebSearchToolOptions struct {
|
||||
Proxy string
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
|
||||
}
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
|
||||
}
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
|
||||
}
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.TavilyMaxResults > 0 {
|
||||
maxResults = opts.TavilyMaxResults
|
||||
}
|
||||
} else if opts.DuckDuckGoEnabled {
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
|
||||
}
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &WebSearchTool{
|
||||
provider: provider,
|
||||
maxResults: maxResults,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Name() string {
|
||||
@@ -527,7 +540,17 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) *WebFetchTool {
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) *WebFetchTool {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
client, err := createHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fetchLimitBytes <= 0 {
|
||||
fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
|
||||
@@ -598,20 +621,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(t.proxy, 60*time.Second)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
|
||||
}
|
||||
|
||||
// Configure redirect handling
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("stopped after 5 redirects")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
@@ -669,14 +679,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
ForLLM: string(resultJSON),
|
||||
ForUser: fmt.Sprintf(
|
||||
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
|
||||
len(text),
|
||||
urlStr,
|
||||
extractor,
|
||||
truncated,
|
||||
),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+45
-24
@@ -36,14 +36,14 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain the fetched content
|
||||
if !strings.Contains(result.ForUser, "Test Page") {
|
||||
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
|
||||
// ForLLM should contain the fetched content (full JSON result)
|
||||
if !strings.Contains(result.ForLLM, "Test Page") {
|
||||
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should contain summary
|
||||
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
|
||||
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
|
||||
// ForUser should contain summary
|
||||
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
||||
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,9 +72,9 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain formatted JSON
|
||||
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
|
||||
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
|
||||
// ForLLM should contain formatted JSON
|
||||
if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
|
||||
t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,9 +163,9 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain truncated content (not the full 20000 chars)
|
||||
// ForLLM should contain truncated content (not the full 20000 chars)
|
||||
resultMap := make(map[string]any)
|
||||
json.Unmarshal([]byte(result.ForUser), &resultMap)
|
||||
json.Unmarshal([]byte(result.ForLLM), &resultMap)
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if len(text) > 1100 { // Allow some margin
|
||||
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
||||
@@ -220,13 +220,19 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when Brave API key is empty")
|
||||
}
|
||||
|
||||
// Also nil when nothing is enabled
|
||||
tool = NewWebSearchTool(WebSearchToolOptions{})
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when no provider is enabled")
|
||||
}
|
||||
@@ -234,7 +240,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
|
||||
@@ -272,14 +281,14 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
|
||||
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
|
||||
// ForLLM should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
|
||||
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should NOT contain script or style tags
|
||||
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
|
||||
// Should NOT contain script or style tags in ForLLM
|
||||
if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,12 +507,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "k",
|
||||
PerplexityMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*PerplexitySearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
|
||||
@@ -514,12 +526,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("brave", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "k",
|
||||
BraveMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*BraveSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
|
||||
@@ -530,11 +545,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("duckduckgo", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: true,
|
||||
DuckDuckGoMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
|
||||
@@ -586,12 +604,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
|
||||
Reference in New Issue
Block a user