diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index cf7f3563a..3010c1451 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -211,6 +211,9 @@ func gatewayCmd() { <-sigChan fmt.Println("\nShutting down...") + if cp, ok := provider.(providers.StatefulProvider); ok { + cp.Close() + } cancel() healthServer.Stop(context.Background()) deviceService.Stop() diff --git a/config/config.example.json b/config/config.example.json index e8c6b3d3f..9575039f8 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -217,7 +217,8 @@ "enabled": false, "api_key": "pplx-xxx", "max_results": 5 - } + }, + "proxy": "" }, "cron": { "exec_timeout_minutes": 5 diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 9a2bb1198..dbc4a9b87 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -106,10 +106,11 @@ func registerSharedTools( PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + Proxy: cfg.Tools.Web.Proxy, }); searchTool != nil { agent.Tools.Register(searchTool) } - agent.Tools.Register(tools.NewWebFetchTool(50000)) + agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)) // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms agent.Tools.Register(tools.NewI2CTool()) diff --git a/pkg/config/config.go b/pkg/config/config.go index 978218251..fa9ec93da 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -453,6 +453,9 @@ type WebToolsConfig struct { Tavily TavilyConfig `json:"tavily"` DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` Perplexity PerplexityConfig `json:"perplexity"` + // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). + // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` } type CronToolsConfig struct { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index f88c0269c..223ac798d 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -392,3 +392,24 @@ func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) { t.Fatal("OpenAI codex web search should be false when disabled in config file") } } + +func TestLoadConfig_WebToolsProxy(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + configJSON := `{ + "agents": {"defaults":{"workspace":"./workspace","model":"gpt4","max_tokens":8192,"max_tool_iterations":20}}, + "model_list": [{"model_name":"gpt4","model":"openai/gpt-5.2","api_key":"x"}], + "tools": {"web":{"proxy":"http://127.0.0.1:7890"}} +}` + if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil { + t.Fatalf("os.WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Tools.Web.Proxy != "http://127.0.0.1:7890" { + t.Fatalf("Tools.Web.Proxy = %q, want %q", cfg.Tools.Web.Proxy, "http://127.0.0.1:7890") + } +} diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go index 6124881f7..9210021e1 100644 --- a/pkg/providers/github_copilot_provider.go +++ b/pkg/providers/github_copilot_provider.go @@ -4,60 +4,84 @@ import ( "context" "encoding/json" "fmt" + "sync" copilot "github.com/github/copilot-sdk/go" ) type GitHubCopilotProvider struct { uri string - connectMode string // `stdio` or `grpc`` + connectMode string // "stdio" or "grpc" + client *copilot.Client session *copilot.Session + + mu sync.Mutex } func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) { - var session *copilot.Session if connectMode == "" { connectMode = "grpc" } - switch connectMode { + switch connectMode { case "stdio": - // todo + // TODO: + return nil, fmt.Errorf("stdio mode not implemented") case "grpc": client := copilot.NewClient(&copilot.ClientOptions{ CLIUrl: uri, }) if err := client.Start(context.Background()); err != nil { return nil, fmt.Errorf( - "Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details", + "can't connect to Github Copilot: %w; `https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server` for details", + err, ) } - defer client.Stop() - session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{ + + session, err := client.CreateSession(context.Background(), &copilot.SessionConfig{ Model: model, Hooks: &copilot.SessionHooks{}, }) + if err != nil { + client.Stop() + return nil, fmt.Errorf("create session failed: %w", err) + } + + return &GitHubCopilotProvider{ + uri: uri, + connectMode: connectMode, + client: client, + session: session, + }, nil + default: + return nil, fmt.Errorf("unknown connect mode: %s", connectMode) + } +} + +func (p *GitHubCopilotProvider) Close() { + p.mu.Lock() + defer p.mu.Unlock() + if p.client != nil { + p.client.Stop() + p.client = nil + p.session = nil } - - return &GitHubCopilotProvider{ - uri: uri, - connectMode: connectMode, - session: session, - }, nil } -// Chat sends a chat request to GitHub Copilot func (p *GitHubCopilotProvider) Chat( - ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, ) (*LLMResponse, error) { type tempMessage struct { Role string `json:"role"` Content string `json:"content"` } out := make([]tempMessage, 0, len(messages)) - for _, msg := range messages { out = append(out, tempMessage{ Role: msg.Role, @@ -65,12 +89,30 @@ func (p *GitHubCopilotProvider) Chat( }) } - fullcontent, _ := json.Marshal(out) + fullcontent, err := json.Marshal(out) + if err != nil { + return nil, fmt.Errorf("marshal messages: %w", err) + } + p.mu.Lock() + session := p.session + p.mu.Unlock() - content, _ := p.session.Send(ctx, copilot.MessageOptions{ + if session == nil { + return nil, fmt.Errorf("provider closed") + } + + resp, err := session.SendAndWait(ctx, copilot.MessageOptions{ Prompt: string(fullcontent), }) + if resp == nil { + return nil, fmt.Errorf("empty response from copilot") + } + if resp.Data.Content == nil { + return nil, fmt.Errorf("no content in copilot response") + } + content := *resp.Data.Content + return &LLMResponse{ FinishReason: "stop", Content: content, diff --git a/pkg/providers/types.go b/pkg/providers/types.go index f711e7803..b2dda04a5 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -30,6 +30,11 @@ type LLMProvider interface { GetDefaultModel() string } +type StatefulProvider interface { + LLMProvider + Close() +} + // FailoverReason classifies why an LLM request failed for fallback decisions. type FailoverReason string diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 452e95e0f..968579dea 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -17,12 +17,50 @@ const ( userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" ) +// createHTTPClient creates an HTTP client with optional proxy support +func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { + client := &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: false, + TLSHandshakeTimeout: 15 * time.Second, + }, + } + + if proxyURL != "" { + proxy, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + scheme := strings.ToLower(proxy.Scheme) + switch scheme { + case "http", "https", "socks5", "socks5h": + default: + return nil, fmt.Errorf( + "unsupported proxy scheme %q (supported: http, https, socks5, socks5h)", + proxy.Scheme, + ) + } + if proxy.Host == "" { + return nil, fmt.Errorf("invalid proxy URL: missing host") + } + client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy) + } else { + client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment + } + + return client, nil +} + type SearchProvider interface { Search(ctx context.Context, query string, count int) (string, error) } type BraveSearchProvider struct { apiKey string + proxy string } func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -37,7 +75,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in req.Header.Set("Accept", "application/json") req.Header.Set("X-Subscription-Token", p.apiKey) - client := &http.Client{Timeout: 10 * time.Second} + client, err := createHTTPClient(p.proxy, 10*time.Second) + if err != nil { + return "", fmt.Errorf("failed to create HTTP client: %w", err) + } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) @@ -167,7 +208,9 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i return strings.Join(lines, "\n"), nil } -type DuckDuckGoSearchProvider struct{} +type DuckDuckGoSearchProvider struct { + proxy string +} func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query)) @@ -179,7 +222,10 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou req.Header.Set("User-Agent", userAgent) - client := &http.Client{Timeout: 10 * time.Second} + client, err := createHTTPClient(p.proxy, 10*time.Second) + if err != nil { + return "", fmt.Errorf("failed to create HTTP client: %w", err) + } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) @@ -261,6 +307,7 @@ func stripTags(content string) string { type PerplexitySearchProvider struct { apiKey string + proxy string } func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -295,7 +342,10 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou req.Header.Set("Authorization", "Bearer "+p.apiKey) req.Header.Set("User-Agent", userAgent) - client := &http.Client{Timeout: 30 * time.Second} + client, err := createHTTPClient(p.proxy, 30*time.Second) + if err != nil { + return "", fmt.Errorf("failed to create HTTP client: %w", err) + } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) @@ -348,6 +398,7 @@ type WebSearchToolOptions struct { PerplexityAPIKey string PerplexityMaxResults int PerplexityEnabled bool + Proxy string } func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { @@ -356,12 +407,12 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { // Priority: Perplexity > Brave > Tavily > DuckDuckGo if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { - provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey} + provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy} if opts.PerplexityMaxResults > 0 { maxResults = opts.PerplexityMaxResults } } else if opts.BraveEnabled && opts.BraveAPIKey != "" { - provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey} + provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy} if opts.BraveMaxResults > 0 { maxResults = opts.BraveMaxResults } @@ -374,7 +425,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { maxResults = opts.TavilyMaxResults } } else if opts.DuckDuckGoEnabled { - provider = &DuckDuckGoSearchProvider{} + provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy} if opts.DuckDuckGoMaxResults > 0 { maxResults = opts.DuckDuckGoMaxResults } @@ -441,6 +492,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR type WebFetchTool struct { maxChars int + proxy string } func NewWebFetchTool(maxChars int) *WebFetchTool { @@ -452,6 +504,16 @@ func NewWebFetchTool(maxChars int) *WebFetchTool { } } +func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool { + if maxChars <= 0 { + maxChars = 50000 + } + return &WebFetchTool{ + maxChars: maxChars, + proxy: proxy, + } +} + func (t *WebFetchTool) Name() string { return "web_fetch" } @@ -511,20 +573,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe req.Header.Set("User-Agent", userAgent) - client := &http.Client{ - Timeout: 60 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - DisableCompression: false, - TLSHandshakeTimeout: 15 * time.Second, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= 5 { - return fmt.Errorf("stopped after 5 redirects") - } - return nil - }, + client, err := createHTTPClient(t.proxy, 60*time.Second) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err)) + } + + // Configure redirect handling + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= 5 { + return fmt.Errorf("stopped after 5 redirects") + } + return nil } resp, err := client.Do(req) diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 75e0d8d16..2cd79eb24 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" ) // TestWebTool_WebFetch_Success verifies successful URL fetching @@ -334,6 +335,172 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) { } } +func TestCreateHTTPClient_ProxyConfigured(t *testing.T) { + client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second) + if err != nil { + t.Fatalf("createHTTPClient() error: %v", err) + } + if client.Timeout != 12*time.Second { + t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second) + } + + tr, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) + } + if tr.Proxy == nil { + t.Fatal("transport.Proxy is nil, want non-nil") + } + + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + proxyURL, err := tr.Proxy(req) + if err != nil { + t.Fatalf("transport.Proxy(req) error: %v", err) + } + if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" { + t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890") + } +} + +func TestCreateHTTPClient_InvalidProxy(t *testing.T) { + _, err := createHTTPClient("://bad-proxy", 10*time.Second) + if err == nil { + t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil") + } +} + +func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) { + client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second) + if err != nil { + t.Fatalf("createHTTPClient() error: %v", err) + } + + tr, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) + } + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + proxyURL, err := tr.Proxy(req) + if err != nil { + t.Fatalf("transport.Proxy(req) error: %v", err) + } + if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" { + t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080") + } +} + +func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) { + _, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second) + if err == nil { + t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil") + } + if !strings.Contains(err.Error(), "unsupported proxy scheme") { + t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme") + } +} + +func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888") + t.Setenv("http_proxy", "http://127.0.0.1:8888") + t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") + t.Setenv("https_proxy", "http://127.0.0.1:8888") + t.Setenv("ALL_PROXY", "") + t.Setenv("all_proxy", "") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + client, err := createHTTPClient("", 10*time.Second) + if err != nil { + t.Fatalf("createHTTPClient() error: %v", err) + } + + tr, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) + } + if tr.Proxy == nil { + t.Fatal("transport.Proxy is nil, want proxy function from environment") + } + + req, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + if _, err := tr.Proxy(req); err != nil { + t.Fatalf("transport.Proxy(req) error: %v", err) + } +} + +func TestNewWebFetchToolWithProxy(t *testing.T) { + tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890") + if tool.maxChars != 1024 { + t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024) + } + if tool.proxy != "http://127.0.0.1:7890" { + t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890") + } + + tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890") + if tool.maxChars != 50000 { + t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000) + } +} + +func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { + t.Run("perplexity", func(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{ + PerplexityEnabled: true, + PerplexityAPIKey: "k", + PerplexityMaxResults: 3, + Proxy: "http://127.0.0.1:7890", + }) + p, ok := tool.provider.(*PerplexitySearchProvider) + if !ok { + t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider) + } + if p.proxy != "http://127.0.0.1:7890" { + t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") + } + }) + + t.Run("brave", func(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{ + BraveEnabled: true, + BraveAPIKey: "k", + BraveMaxResults: 3, + Proxy: "http://127.0.0.1:7890", + }) + p, ok := tool.provider.(*BraveSearchProvider) + if !ok { + t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider) + } + if p.proxy != "http://127.0.0.1:7890" { + t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") + } + }) + + t.Run("duckduckgo", func(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{ + DuckDuckGoEnabled: true, + DuckDuckGoMaxResults: 3, + Proxy: "http://127.0.0.1:7890", + }) + p, ok := tool.provider.(*DuckDuckGoSearchProvider) + if !ok { + t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider) + } + if p.proxy != "http://127.0.0.1:7890" { + t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") + } + }) +} + // TestWebTool_TavilySearch_Success verifies successful Tavily search func TestWebTool_TavilySearch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {