diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 23e422ab1..af26509e6 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -5,7 +5,7 @@ version: 2 before: hooks: - go mod tidy - - go generate ./cmd/picoclaw + - go generate ./cmd/picoclaw/... builds: - id: picoclaw @@ -73,7 +73,6 @@ nfpms: package_name: picoclaw file_name_template: >- {{ .PackageName }}_ - {{- .Version }}_ {{- if eq .Arch "amd64" }}x86_64 {{- else if eq .Arch "arm64" }}aarch64 {{- else if eq .Arch "arm" }}armv{{ .Arm }} diff --git a/Makefile b/Makefile index c59c414f3..a14723616 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,8 @@ ifeq ($(UNAME_S),Linux) ARCH=amd64 else ifeq ($(UNAME_M),aarch64) ARCH=arm64 + else ifeq ($(UNAME_M),armv81) + ARCH=arm64 else ifeq ($(UNAME_M),loongarch64) ARCH=loong64 else ifeq ($(UNAME_M),riscv64) diff --git a/README.fr.md b/README.fr.md index f59807739..f1d4f848e 100644 --- a/README.fr.md +++ b/README.fr.md @@ -221,6 +221,7 @@ picoclaw onboard "model_name": "gpt4", "model": "openai/gpt-5.2", "api_key": "sk-your-openai-key", + "request_timeout": 300, "api_base": "https://api.openai.com/v1" } ], @@ -252,6 +253,9 @@ picoclaw onboard } ``` +> **Nouveau** : Le format de configuration `model_list` permet d'ajouter des fournisseurs sans modifier le code. Voir [Configuration de Modèle](#configuration-de-modèle-model_list) pour plus de détails. +> `request_timeout` est optionnel et s'exprime en secondes. S'il est omis ou défini à `<= 0`, PicoClaw utilise le délai d'expiration par défaut (120s). + **3. Obtenir des Clés API** * **Fournisseur LLM** : [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) @@ -979,6 +983,17 @@ Cette conception permet également le **support multi-agent** avec une sélectio ``` > Exécutez `picoclaw auth login --provider anthropic` pour configurer les identifiants OAuth. +**Proxy/API personnalisée** +```json +{ + "model_name": "my-custom-model", + "model": "openai/custom-model", + "api_base": "https://my-proxy.com/v1", + "api_key": "sk-...", + "request_timeout": 300 +} +``` + #### Équilibrage de Charge Configurez plusieurs points de terminaison pour le même nom de modèle—PicoClaw utilisera automatiquement le round-robin entre eux : diff --git a/README.ja.md b/README.ja.md index 5a7bb8542..48fb89fe3 100644 --- a/README.ja.md +++ b/README.ja.md @@ -183,6 +183,7 @@ picoclaw onboard "model_name": "gpt4", "model": "openai/gpt-5.2", "api_key": "sk-your-openai-key", + "request_timeout": 300, "api_base": "https://api.openai.com/v1" } ], @@ -221,6 +222,9 @@ picoclaw onboard } ``` +> **新機能**: `model_list` 形式により、プロバイダーをコード変更なしで追加できます。詳細は [モデル設定](#モデル設定-model_list) を参照してください。 +> `request_timeout` は任意の秒単位設定です。省略または `<= 0` の場合、PicoClaw はデフォルトのタイムアウト(120秒)を使用します。 + **3. API キーの取得** - **LLM プロバイダー**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) @@ -918,6 +922,17 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る ``` > OAuth認証を設定するには、`picoclaw auth login --provider anthropic` を実行してください。 +**カスタムプロキシ/API** +```json +{ + "model_name": "my-custom-model", + "model": "openai/custom-model", + "api_base": "https://my-proxy.com/v1", + "api_key": "sk-...", + "request_timeout": 300 +} +``` + #### ロードバランシング 同じモデル名で複数のエンドポイントを設定すると、PicoClaw が自動的にラウンドロビンで分散します: diff --git a/README.md b/README.md index aa7b0719a..72a933b6f 100644 --- a/README.md +++ b/README.md @@ -232,7 +232,8 @@ picoclaw onboard { "model_name": "gpt4", "model": "openai/gpt-5.2", - "api_key": "your-api-key" + "api_key": "your-api-key", + "request_timeout": 300 }, { "model_name": "claude-sonnet-4.6", @@ -262,6 +263,7 @@ picoclaw onboard ``` > **New**: The `model_list` configuration format allows zero-code provider addition. See [Model Configuration](#model-configuration-model_list) for details. +> `request_timeout` is optional and uses seconds. If omitted or set to `<= 0`, PicoClaw uses the default timeout (120s). **3. Get API Keys** @@ -915,7 +917,8 @@ This design also enables **multi-agent support** with flexible provider selectio "model_name": "my-custom-model", "model": "openai/custom-model", "api_base": "https://my-proxy.com/v1", - "api_key": "sk-..." + "api_key": "sk-...", + "request_timeout": 300 } ``` diff --git a/README.pt-br.md b/README.pt-br.md index 0115b7f89..1dbee5201 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -222,6 +222,7 @@ picoclaw onboard "model_name": "gpt4", "model": "openai/gpt-5.2", "api_key": "sk-your-openai-key", + "request_timeout": 300, "api_base": "https://api.openai.com/v1" } ], @@ -246,6 +247,9 @@ picoclaw onboard } ``` +> **Novo**: O formato de configuração `model_list` permite adicionar provedores sem alterar código. Veja [Configuração de Modelo](#configuração-de-modelo-model_list) para detalhes. +> `request_timeout` é opcional e usa segundos. Se omitido ou definido como `<= 0`, o PicoClaw usa o timeout padrão (120s). + **3. Obter API Keys** * **Provedor de LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) @@ -973,6 +977,17 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve ``` > Execute `picoclaw auth login --provider anthropic` para configurar credenciais OAuth. +**Proxy/API personalizada** +```json +{ + "model_name": "my-custom-model", + "model": "openai/custom-model", + "api_base": "https://my-proxy.com/v1", + "api_key": "sk-...", + "request_timeout": 300 +} +``` + #### Balanceamento de Carga Configure vários endpoints para o mesmo nome de modelo—PicoClaw fará round-robin automaticamente entre eles: diff --git a/README.vi.md b/README.vi.md index 015bc264e..0dd4994c2 100644 --- a/README.vi.md +++ b/README.vi.md @@ -202,6 +202,7 @@ picoclaw onboard "model_name": "gpt4", "model": "openai/gpt-5.2", "api_key": "sk-your-openai-key", + "request_timeout": 300, "api_base": "https://api.openai.com/v1" } ], @@ -220,6 +221,9 @@ picoclaw onboard } ``` +> **Mới**: Định dạng cấu hình `model_list` cho phép thêm nhà cung cấp mà không cần thay đổi mã nguồn. Xem [Cấu hình Mô hình](#cấu-hình-mô-hình-model_list) để biết chi tiết. +> `request_timeout` là tùy chọn và dùng đơn vị giây. Nếu bỏ qua hoặc đặt `<= 0`, PicoClaw sẽ dùng timeout mặc định (120s). + **3. Lấy API Key** * **Nhà cung cấp LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) @@ -944,6 +948,17 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch ``` > Chạy `picoclaw auth login --provider anthropic` để thiết lập thông tin xác thực OAuth. +**Proxy/API tùy chỉnh** +```json +{ + "model_name": "my-custom-model", + "model": "openai/custom-model", + "api_base": "https://my-proxy.com/v1", + "api_key": "sk-...", + "request_timeout": 300 +} +``` + #### Cân bằng Tải tải Định cấu hình nhiều endpoint cho cùng một tên mô hình—PicoClaw sẽ tự động phân phối round-robin giữa chúng: diff --git a/README.zh.md b/README.zh.md index 4f4bde46a..8ce1ad2ee 100644 --- a/README.zh.md +++ b/README.zh.md @@ -234,7 +234,8 @@ picoclaw onboard { "model_name": "gpt4", "model": "openai/gpt-5.2", - "api_key": "your-api-key" + "api_key": "your-api-key", + "request_timeout": 300 }, { "model_name": "claude-sonnet-4.6", @@ -263,6 +264,7 @@ picoclaw onboard ``` > **新功能**: `model_list` 配置格式支持零代码添加 provider。详见[模型配置](#模型配置-model_list)章节。 +> `request_timeout` 为可选项,单位为秒。若省略或设置为 `<= 0`,PicoClaw 使用默认超时(120 秒)。 **3. 获取 API Key** @@ -550,7 +552,8 @@ Agent 读取 HEARTBEAT.md "model_name": "my-custom-model", "model": "openai/custom-model", "api_base": "https://my-proxy.com/v1", - "api_key": "sk-..." + "api_key": "sk-...", + "request_timeout": 300 } ``` diff --git a/docs/migration/model-list-migration.md b/docs/migration/model-list-migration.md index 589dfc043..0d4af719c 100644 --- a/docs/migration/model-list-migration.md +++ b/docs/migration/model-list-migration.md @@ -117,6 +117,7 @@ The `model` field uses a protocol prefix format: `[protocol/]model-identifier` | `connect_mode` | No | Connection mode for CLI providers: `stdio`, `grpc` | | `rpm` | No | Requests per minute limit | | `max_tokens_field` | No | Field name for max tokens | +| `request_timeout` | No | HTTP request timeout in seconds; `<=0` uses default `120s` | *`api_key` is required for HTTP-based protocols unless `api_base` points to a local server. diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 524494849..6592d9bc0 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -23,6 +23,19 @@ import ( "github.com/sipeed/picoclaw/pkg/voice" ) +var ( + reHeading = regexp.MustCompile(`^#{1,6}\s+(.+)$`) + reBlockquote = regexp.MustCompile(`^>\s*(.*)$`) + reLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`) + reBoldStar = regexp.MustCompile(`\*\*(.+?)\*\*`) + reBoldUnder = regexp.MustCompile(`__(.+?)__`) + reItalic = regexp.MustCompile(`_([^_]+)_`) + reStrike = regexp.MustCompile(`~~(.+?)~~`) + reListItem = regexp.MustCompile(`^[-*]\s+`) + reCodeBlock = regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```") + reInlineCode = regexp.MustCompile("`([^`]+)`") +) + type TelegramChannel struct { *BaseChannel bot *telego.Bot @@ -431,19 +444,18 @@ func markdownToTelegramHTML(text string) string { inlineCodes := extractInlineCodes(text) text = inlineCodes.text - text = regexp.MustCompile(`^#{1,6}\s+(.+)$`).ReplaceAllString(text, "$1") + text = reHeading.ReplaceAllString(text, "$1") - text = regexp.MustCompile(`^>\s*(.*)$`).ReplaceAllString(text, "$1") + text = reBlockquote.ReplaceAllString(text, "$1") text = escapeHTML(text) - text = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(text, `$1`) + text = reLink.ReplaceAllString(text, `$1`) - text = regexp.MustCompile(`\*\*(.+?)\*\*`).ReplaceAllString(text, "$1") + text = reBoldStar.ReplaceAllString(text, "$1") - text = regexp.MustCompile(`__(.+?)__`).ReplaceAllString(text, "$1") + text = reBoldUnder.ReplaceAllString(text, "$1") - reItalic := regexp.MustCompile(`_([^_]+)_`) text = reItalic.ReplaceAllStringFunc(text, func(s string) string { match := reItalic.FindStringSubmatch(s) if len(match) < 2 { @@ -452,9 +464,9 @@ func markdownToTelegramHTML(text string) string { return "" + match[1] + "" }) - text = regexp.MustCompile(`~~(.+?)~~`).ReplaceAllString(text, "$1") + text = reStrike.ReplaceAllString(text, "$1") - text = regexp.MustCompile(`^[-*]\s+`).ReplaceAllString(text, "• ") + text = reListItem.ReplaceAllString(text, "• ") for i, code := range inlineCodes.codes { escaped := escapeHTML(code) @@ -479,8 +491,7 @@ type codeBlockMatch struct { } func extractCodeBlocks(text string) codeBlockMatch { - re := regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```") - matches := re.FindAllStringSubmatch(text, -1) + matches := reCodeBlock.FindAllStringSubmatch(text, -1) codes := make([]string, 0, len(matches)) for _, match := range matches { @@ -488,7 +499,7 @@ func extractCodeBlocks(text string) codeBlockMatch { } i := 0 - text = re.ReplaceAllStringFunc(text, func(m string) string { + text = reCodeBlock.ReplaceAllStringFunc(text, func(m string) string { placeholder := fmt.Sprintf("\x00CB%d\x00", i) i++ return placeholder @@ -503,8 +514,7 @@ type inlineCodeMatch struct { } func extractInlineCodes(text string) inlineCodeMatch { - re := regexp.MustCompile("`([^`]+)`") - matches := re.FindAllStringSubmatch(text, -1) + matches := reInlineCode.FindAllStringSubmatch(text, -1) codes := make([]string, 0, len(matches)) for _, match := range matches { @@ -512,7 +522,7 @@ func extractInlineCodes(text string) inlineCodeMatch { } i := 0 - text = re.ReplaceAllStringFunc(text, func(m string) string { + text = reInlineCode.ReplaceAllStringFunc(text, func(m string) string { placeholder := fmt.Sprintf("\x00IC%d\x00", i) i++ return placeholder diff --git a/pkg/config/config.go b/pkg/config/config.go index fa9ec93da..ddfa35dc9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -371,11 +371,12 @@ func (p ProvidersConfig) MarshalJSON() ([]byte, error) { } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` - Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` - AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` - ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` + RequestTimeout int `json:"request_timeout,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_REQUEST_TIMEOUT"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` + ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc` } type OpenAIProviderConfig struct { @@ -406,6 +407,7 @@ type ModelConfig struct { // Optional optimizations RPM int `json:"rpm,omitempty"` // Requests per minute limit MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") + RequestTimeout int `json:"request_timeout,omitempty"` } // Validate checks if the ModelConfig has all required fields. diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 223ac798d..bf56b7f34 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -413,3 +413,12 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) { t.Fatalf("Tools.Web.Proxy = %q, want %q", cfg.Tools.Web.Proxy, "http://127.0.0.1:7890") } } + +// TestDefaultConfig_DMScope verifies the default dm_scope value +func TestDefaultConfig_DMScope(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Session.DMScope != "per-channel-peer" { + t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope) + } +} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index cc6de9399..cf799140d 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -21,7 +21,7 @@ func DefaultConfig() *Config { }, Bindings: []AgentBinding{}, Session: SessionConfig{ - DMScope: "main", + DMScope: "per-channel-peer", }, Channels: ChannelsConfig{ WhatsApp: WhatsAppConfig{ diff --git a/pkg/config/migration.go b/pkg/config/migration.go index 70e1de438..5deb09270 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -60,12 +60,13 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "openai", - Model: "openai/gpt-5.2", - APIKey: p.OpenAI.APIKey, - APIBase: p.OpenAI.APIBase, - Proxy: p.OpenAI.Proxy, - AuthMethod: p.OpenAI.AuthMethod, + ModelName: "openai", + Model: "openai/gpt-5.2", + APIKey: p.OpenAI.APIKey, + APIBase: p.OpenAI.APIBase, + Proxy: p.OpenAI.Proxy, + RequestTimeout: p.OpenAI.RequestTimeout, + AuthMethod: p.OpenAI.AuthMethod, }, true }, }, @@ -77,12 +78,13 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "anthropic", - Model: "anthropic/claude-sonnet-4.6", - APIKey: p.Anthropic.APIKey, - APIBase: p.Anthropic.APIBase, - Proxy: p.Anthropic.Proxy, - AuthMethod: p.Anthropic.AuthMethod, + ModelName: "anthropic", + Model: "anthropic/claude-sonnet-4.6", + APIKey: p.Anthropic.APIKey, + APIBase: p.Anthropic.APIBase, + Proxy: p.Anthropic.Proxy, + RequestTimeout: p.Anthropic.RequestTimeout, + AuthMethod: p.Anthropic.AuthMethod, }, true }, }, @@ -94,11 +96,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "openrouter", - Model: "openrouter/auto", - APIKey: p.OpenRouter.APIKey, - APIBase: p.OpenRouter.APIBase, - Proxy: p.OpenRouter.Proxy, + ModelName: "openrouter", + Model: "openrouter/auto", + APIKey: p.OpenRouter.APIKey, + APIBase: p.OpenRouter.APIBase, + Proxy: p.OpenRouter.Proxy, + RequestTimeout: p.OpenRouter.RequestTimeout, }, true }, }, @@ -110,11 +113,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "groq", - Model: "groq/llama-3.1-70b-versatile", - APIKey: p.Groq.APIKey, - APIBase: p.Groq.APIBase, - Proxy: p.Groq.Proxy, + ModelName: "groq", + Model: "groq/llama-3.1-70b-versatile", + APIKey: p.Groq.APIKey, + APIBase: p.Groq.APIBase, + Proxy: p.Groq.Proxy, + RequestTimeout: p.Groq.RequestTimeout, }, true }, }, @@ -126,11 +130,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "zhipu", - Model: "zhipu/glm-4", - APIKey: p.Zhipu.APIKey, - APIBase: p.Zhipu.APIBase, - Proxy: p.Zhipu.Proxy, + ModelName: "zhipu", + Model: "zhipu/glm-4", + APIKey: p.Zhipu.APIKey, + APIBase: p.Zhipu.APIBase, + Proxy: p.Zhipu.Proxy, + RequestTimeout: p.Zhipu.RequestTimeout, }, true }, }, @@ -142,11 +147,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "vllm", - Model: "vllm/auto", - APIKey: p.VLLM.APIKey, - APIBase: p.VLLM.APIBase, - Proxy: p.VLLM.Proxy, + ModelName: "vllm", + Model: "vllm/auto", + APIKey: p.VLLM.APIKey, + APIBase: p.VLLM.APIBase, + Proxy: p.VLLM.Proxy, + RequestTimeout: p.VLLM.RequestTimeout, }, true }, }, @@ -158,11 +164,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "gemini", - Model: "gemini/gemini-pro", - APIKey: p.Gemini.APIKey, - APIBase: p.Gemini.APIBase, - Proxy: p.Gemini.Proxy, + ModelName: "gemini", + Model: "gemini/gemini-pro", + APIKey: p.Gemini.APIKey, + APIBase: p.Gemini.APIBase, + Proxy: p.Gemini.Proxy, + RequestTimeout: p.Gemini.RequestTimeout, }, true }, }, @@ -174,11 +181,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "nvidia", - Model: "nvidia/meta/llama-3.1-8b-instruct", - APIKey: p.Nvidia.APIKey, - APIBase: p.Nvidia.APIBase, - Proxy: p.Nvidia.Proxy, + ModelName: "nvidia", + Model: "nvidia/meta/llama-3.1-8b-instruct", + APIKey: p.Nvidia.APIKey, + APIBase: p.Nvidia.APIBase, + Proxy: p.Nvidia.Proxy, + RequestTimeout: p.Nvidia.RequestTimeout, }, true }, }, @@ -190,11 +198,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "ollama", - Model: "ollama/llama3", - APIKey: p.Ollama.APIKey, - APIBase: p.Ollama.APIBase, - Proxy: p.Ollama.Proxy, + ModelName: "ollama", + Model: "ollama/llama3", + APIKey: p.Ollama.APIKey, + APIBase: p.Ollama.APIBase, + Proxy: p.Ollama.Proxy, + RequestTimeout: p.Ollama.RequestTimeout, }, true }, }, @@ -206,11 +215,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "moonshot", - Model: "moonshot/kimi", - APIKey: p.Moonshot.APIKey, - APIBase: p.Moonshot.APIBase, - Proxy: p.Moonshot.Proxy, + ModelName: "moonshot", + Model: "moonshot/kimi", + APIKey: p.Moonshot.APIKey, + APIBase: p.Moonshot.APIBase, + Proxy: p.Moonshot.Proxy, + RequestTimeout: p.Moonshot.RequestTimeout, }, true }, }, @@ -222,11 +232,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "shengsuanyun", - Model: "shengsuanyun/auto", - APIKey: p.ShengSuanYun.APIKey, - APIBase: p.ShengSuanYun.APIBase, - Proxy: p.ShengSuanYun.Proxy, + ModelName: "shengsuanyun", + Model: "shengsuanyun/auto", + APIKey: p.ShengSuanYun.APIKey, + APIBase: p.ShengSuanYun.APIBase, + Proxy: p.ShengSuanYun.Proxy, + RequestTimeout: p.ShengSuanYun.RequestTimeout, }, true }, }, @@ -238,11 +249,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "deepseek", - Model: "deepseek/deepseek-chat", - APIKey: p.DeepSeek.APIKey, - APIBase: p.DeepSeek.APIBase, - Proxy: p.DeepSeek.Proxy, + ModelName: "deepseek", + Model: "deepseek/deepseek-chat", + APIKey: p.DeepSeek.APIKey, + APIBase: p.DeepSeek.APIBase, + Proxy: p.DeepSeek.Proxy, + RequestTimeout: p.DeepSeek.RequestTimeout, }, true }, }, @@ -254,11 +266,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "cerebras", - Model: "cerebras/llama-3.3-70b", - APIKey: p.Cerebras.APIKey, - APIBase: p.Cerebras.APIBase, - Proxy: p.Cerebras.Proxy, + ModelName: "cerebras", + Model: "cerebras/llama-3.3-70b", + APIKey: p.Cerebras.APIKey, + APIBase: p.Cerebras.APIBase, + Proxy: p.Cerebras.Proxy, + RequestTimeout: p.Cerebras.RequestTimeout, }, true }, }, @@ -270,11 +283,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "volcengine", - Model: "volcengine/doubao-pro", - APIKey: p.VolcEngine.APIKey, - APIBase: p.VolcEngine.APIBase, - Proxy: p.VolcEngine.Proxy, + ModelName: "volcengine", + Model: "volcengine/doubao-pro", + APIKey: p.VolcEngine.APIKey, + APIBase: p.VolcEngine.APIBase, + Proxy: p.VolcEngine.Proxy, + RequestTimeout: p.VolcEngine.RequestTimeout, }, true }, }, @@ -316,11 +330,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "qwen", - Model: "qwen/qwen-max", - APIKey: p.Qwen.APIKey, - APIBase: p.Qwen.APIBase, - Proxy: p.Qwen.Proxy, + ModelName: "qwen", + Model: "qwen/qwen-max", + APIKey: p.Qwen.APIKey, + APIBase: p.Qwen.APIBase, + Proxy: p.Qwen.Proxy, + RequestTimeout: p.Qwen.RequestTimeout, }, true }, }, @@ -332,11 +347,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return ModelConfig{}, false } return ModelConfig{ - ModelName: "mistral", - Model: "mistral/mistral-small-latest", - APIKey: p.Mistral.APIKey, - APIBase: p.Mistral.APIBase, - Proxy: p.Mistral.Proxy, + ModelName: "mistral", + Model: "mistral/mistral-small-latest", + APIKey: p.Mistral.APIKey, + APIBase: p.Mistral.APIBase, + Proxy: p.Mistral.Proxy, + RequestTimeout: p.Mistral.RequestTimeout, }, true }, }, diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index 42165cb71..db8f4657d 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -166,6 +166,27 @@ func TestConvertProvidersToModelList_Proxy(t *testing.T) { } } +func TestConvertProvidersToModelList_RequestTimeout(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + Ollama: ProviderConfig{ + APIKey: "ollama-key", + RequestTimeout: 300, + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].RequestTimeout != 300 { + t.Errorf("RequestTimeout = %d, want %d", result[0].RequestTimeout, 300) + } +} + func TestConvertProvidersToModelList_AuthMethod(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go index 99eea2782..084f50a82 100644 --- a/pkg/config/model_config_test.go +++ b/pkg/config/model_config_test.go @@ -365,3 +365,38 @@ func TestConfig_ValidateModelList(t *testing.T) { }) } } + +func TestModelConfig_RequestTimeoutParsing(t *testing.T) { + jsonData := `{ + "model_name": "slow-local", + "model": "openai/local-model", + "api_base": "http://localhost:11434/v1", + "request_timeout": 300 + }` + + var cfg ModelConfig + if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if cfg.RequestTimeout != 300 { + t.Fatalf("RequestTimeout = %d, want 300", cfg.RequestTimeout) + } +} + +func TestModelConfig_RequestTimeoutDefaultZeroValue(t *testing.T) { + jsonData := `{ + "model_name": "default-timeout", + "model": "openai/gpt-4o", + "api_key": "test-key" + }` + + var cfg ModelConfig + if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if cfg.RequestTimeout != 0 { + t.Fatalf("RequestTimeout = %d, want 0", cfg.RequestTimeout) + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 7d5566eef..53f7a08a0 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -84,7 +84,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if apiBase == "" { apiBase = getDefaultAPIBase(protocol) } - return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil + return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( + cfg.APIKey, + apiBase, + cfg.Proxy, + cfg.MaxTokensField, + cfg.RequestTimeout, + ), modelID, nil case "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", @@ -97,7 +103,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if apiBase == "" { apiBase = getDefaultAPIBase(protocol) } - return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil + return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( + cfg.APIKey, + apiBase, + cfg.Proxy, + cfg.MaxTokensField, + cfg.RequestTimeout, + ), modelID, nil case "anthropic": if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" { @@ -116,7 +128,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if cfg.APIKey == "" { return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model) } - return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil + return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( + cfg.APIKey, + apiBase, + cfg.Proxy, + cfg.MaxTokensField, + cfg.RequestTimeout, + ), modelID, nil case "antigravity": return NewAntigravityProvider(), modelID, nil diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 6b133101a..e0c0eddef 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -6,7 +6,11 @@ package providers import ( + "net/http" + "net/http/httptest" + "strings" "testing" + "time" "github.com/sipeed/picoclaw/pkg/config" ) @@ -247,3 +251,42 @@ func TestCreateProviderFromConfig_EmptyModel(t *testing.T) { t.Fatal("CreateProviderFromConfig() expected error for empty model") } } + +func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1500 * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`)) + })) + defer server.Close() + + cfg := &config.ModelConfig{ + ModelName: "test-timeout", + Model: "openai/gpt-4o", + APIBase: server.URL, + RequestTimeout: 1, + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if modelID != "gpt-4o" { + t.Fatalf("modelID = %q, want %q", modelID, "gpt-4o") + } + + _, err = provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + modelID, + nil, + ) + if err == nil { + t.Fatal("Chat() expected timeout error, got nil") + } + errMsg := err.Error() + if !strings.Contains(errMsg, "context deadline exceeded") && !strings.Contains(errMsg, "Client.Timeout exceeded") { + t.Fatalf("Chat() error = %q, want timeout-related error", errMsg) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index d0c4344f3..5c328f418 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -8,6 +8,7 @@ package providers import ( "context" + "time" "github.com/sipeed/picoclaw/pkg/providers/openai_compat" ) @@ -23,8 +24,21 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { } func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider { + return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0) +} + +func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( + apiKey, apiBase, proxy, maxTokensField string, + requestTimeoutSeconds int, +) *HTTPProvider { return &HTTPProvider{ - delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField), + delegate: openai_compat.NewProvider( + apiKey, + apiBase, + proxy, + openai_compat.WithMaxTokensField(maxTokensField), + openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second), + ), } } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 087d3506e..7dace71f2 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -34,13 +34,27 @@ type Provider struct { httpClient *http.Client } -func NewProvider(apiKey, apiBase, proxy string) *Provider { - return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "") +type Option func(*Provider) + +const defaultRequestTimeout = 120 * time.Second + +func WithMaxTokensField(maxTokensField string) Option { + return func(p *Provider) { + p.maxTokensField = maxTokensField + } } -func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider { +func WithRequestTimeout(timeout time.Duration) Option { + return func(p *Provider) { + if timeout > 0 { + p.httpClient.Timeout = timeout + } + } +} + +func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { client := &http.Client{ - Timeout: 120 * time.Second, + Timeout: defaultRequestTimeout, } if proxy != "" { @@ -54,12 +68,36 @@ func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string } } - return &Provider{ - apiKey: apiKey, - apiBase: strings.TrimRight(apiBase, "/"), - maxTokensField: maxTokensField, - httpClient: client, + p := &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: client, } + + for _, opt := range opts { + if opt != nil { + opt(p) + } + } + + return p +} + +func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider { + return NewProvider(apiKey, apiBase, proxy, WithMaxTokensField(maxTokensField)) +} + +func NewProviderWithMaxTokensFieldAndTimeout( + apiKey, apiBase, proxy, maxTokensField string, + requestTimeoutSeconds int, +) *Provider { + return NewProvider( + apiKey, + apiBase, + proxy, + WithMaxTokensField(maxTokensField), + WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second), + ) } func (p *Provider) Chat( diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 594a48213..7247fea3e 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "net/url" "testing" + "time" ) func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { @@ -325,3 +326,38 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) { t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto") } } + +func TestProvider_RequestTimeoutDefault(t *testing.T) { + p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 0) + if p.httpClient.Timeout != defaultRequestTimeout { + t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout) + } +} + +func TestProvider_RequestTimeoutOverride(t *testing.T) { + p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 300) + if p.httpClient.Timeout != 300*time.Second { + t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second) + } +} + +func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) { + p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens")) + if p.maxTokensField != "max_completion_tokens" { + t.Fatalf("maxTokensField = %q, want %q", p.maxTokensField, "max_completion_tokens") + } +} + +func TestProvider_FunctionalOptionRequestTimeout(t *testing.T) { + p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(45*time.Second)) + if p.httpClient.Timeout != 45*time.Second { + t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 45*time.Second) + } +} + +func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) { + p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(-1*time.Second)) + if p.httpClient.Timeout != defaultRequestTimeout { + t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout) + } +} diff --git a/pkg/skills/installer.go b/pkg/skills/installer.go index d6ff5f3a3..31b5a3dbd 100644 --- a/pkg/skills/installer.go +++ b/pkg/skills/installer.go @@ -10,6 +10,8 @@ import ( "path/filepath" "time" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/fileutil" ) @@ -46,7 +48,7 @@ func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) er return fmt.Errorf("failed to create request: %w", err) } - resp, err := client.Do(req) + resp, err := utils.DoRequestWithRetry(client, req) if err != nil { return fmt.Errorf("failed to fetch skill: %w", err) } @@ -66,7 +68,7 @@ func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) er } skillPath := filepath.Join(skillDir, "SKILL.md") - + // Use unified atomic write utility with explicit sync for flash storage reliability. if err := fileutil.WriteFileAtomic(skillPath, body, 0o600); err != nil { return fmt.Errorf("failed to write skill file: %w", err) @@ -98,7 +100,7 @@ func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableS return nil, fmt.Errorf("failed to create request: %w", err) } - resp, err := client.Do(req) + resp, err := utils.DoRequestWithRetry(client, req) if err != nil { return nil, fmt.Errorf("failed to fetch skills list: %w", err) } diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index 5749d8983..67d3e70e0 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -13,7 +13,11 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" ) -var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`) +var ( + namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`) + reFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`) + reStripFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`) +) const ( MaxNameLength = 64 @@ -257,10 +261,7 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string { func (sl *SkillsLoader) extractFrontmatter(content string) string { // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks - // (?s) enables DOTALL so . matches newlines; - // ^--- at start, then ... --- at start of line, honoring all three line ending types - re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`) - match := re.FindStringSubmatch(content) + match := reFrontmatter.FindStringSubmatch(content) if len(match) > 1 { return match[1] } @@ -268,12 +269,7 @@ func (sl *SkillsLoader) extractFrontmatter(content string) string { } func (sl *SkillsLoader) stripFrontmatter(content string) string { - // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks - // (?s) enables DOTALL so . matches newlines; - // ^--- at start, then ... --- at start of line, honoring all three line ending types - // Match zero or more trailing line endings after closing --- (handles both with and without blank lines) - re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`) - return re.ReplaceAllString(content, "") + return reStripFrontmatter.ReplaceAllString(content, "") } func escapeXML(s string) string { diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 44df28215..8ba2a723a 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -17,6 +17,19 @@ const ( userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" ) +// Pre-compiled regexes for HTML text extraction +var ( + reScript = regexp.MustCompile(``) + reStyle = regexp.MustCompile(``) + reTags = regexp.MustCompile(`<[^>]+>`) + reWhitespace = regexp.MustCompile(`[^\S\n]+`) + reBlankLines = regexp.MustCompile(`\n{3,}`) + + // DuckDuckGo result extraction + reDDGLink = regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) + reDDGSnippet = regexp.MustCompile(`([\s\S]*?)`) +) + // createHTTPClient creates an HTTP client with optional proxy support func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { client := &http.Client{ @@ -251,8 +264,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query // Try finding the result links directly first, as they are the most critical // Pattern: Title // The previous regex was a bit strict. Let's make it more flexible for attributes order/content - reLink := regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) - matches := reLink.FindAllStringSubmatch(html, count+5) + matches := reDDGLink.FindAllStringSubmatch(html, count+5) if len(matches) == 0 { return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil @@ -269,8 +281,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query // A better regex approach: iterate through text and find matches in order // But for now, let's grab all snippets too - reSnippet := regexp.MustCompile(`([\s\S]*?)`) - snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5) + snippetMatches := reDDGSnippet.FindAllStringSubmatch(html, count+5) maxItems := min(len(matches), count) @@ -305,8 +316,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query } func stripTags(content string) string { - re := regexp.MustCompile(`<[^>]+>`) - return re.ReplaceAllString(content, "") + return reTags.ReplaceAllString(content, "") } type PerplexitySearchProvider struct { @@ -654,19 +664,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe } func (t *WebFetchTool) extractText(htmlContent string) string { - re := regexp.MustCompile(``) - result := re.ReplaceAllLiteralString(htmlContent, "") - re = regexp.MustCompile(``) - result = re.ReplaceAllLiteralString(result, "") - re = regexp.MustCompile(`<[^>]+>`) - result = re.ReplaceAllLiteralString(result, "") + result := reScript.ReplaceAllLiteralString(htmlContent, "") + result = reStyle.ReplaceAllLiteralString(result, "") + result = reTags.ReplaceAllLiteralString(result, "") result = strings.TrimSpace(result) - re = regexp.MustCompile(`[^\S\n]+`) - result = re.ReplaceAllString(result, " ") - re = regexp.MustCompile(`\n{3,}`) - result = re.ReplaceAllString(result, "\n\n") + result = reWhitespace.ReplaceAllString(result, " ") + result = reBlankLines.ReplaceAllString(result, "\n\n") lines := strings.Split(result, "\n") var cleanLines []string diff --git a/pkg/utils/http_retry.go b/pkg/utils/http_retry.go new file mode 100644 index 000000000..e90fa2129 --- /dev/null +++ b/pkg/utils/http_retry.go @@ -0,0 +1,57 @@ +package utils + +import ( + "context" + "fmt" + "net/http" + "time" +) + +const maxRetries = 3 + +var retryDelayUnit = time.Second + +func shouldRetry(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || + statusCode >= 500 +} + +func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response, error) { + var resp *http.Response + var err error + + for i := range maxRetries { + if i > 0 && resp != nil { + resp.Body.Close() + } + + resp, err = client.Do(req) + if err == nil { + if resp.StatusCode == http.StatusOK { + break + } + if !shouldRetry(resp.StatusCode) { + break + } + } + + if i < maxRetries-1 { + if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil { + return nil, fmt.Errorf("failed to sleep: %w", err) + } + } + } + return resp, err +} + +func sleepWithCtx(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/pkg/utils/http_retry_test.go b/pkg/utils/http_retry_test.go new file mode 100644 index 000000000..1c2dbe115 --- /dev/null +++ b/pkg/utils/http_retry_test.go @@ -0,0 +1,118 @@ +package utils + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDoRequestWithRetry(t *testing.T) { + retryDelayUnit = time.Millisecond + t.Cleanup(func() { retryDelayUnit = time.Second }) + + testcases := []struct { + name string + serverBehavior func(*httptest.Server) int + wantSuccess bool + wantAttempts int + }{ + { + name: "success-on-first-attempt", + serverBehavior: func(server *httptest.Server) int { + return 0 + }, + wantSuccess: true, + wantAttempts: 1, + }, + { + name: "fail-all-attempts", + serverBehavior: func(server *httptest.Server) int { + return 4 + }, + wantSuccess: false, + wantAttempts: 3, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts <= tc.serverBehavior(nil) { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + + t.Cleanup(func() { + server.Close() + }) + + client := &http.Client{Timeout: 5 * time.Second} + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := DoRequestWithRetry(client, req) + + if tc.wantSuccess { + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + } else { + require.NotNil(t, resp) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + resp.Body.Close() + } + + assert.Equal(t, tc.wantAttempts, attempts) + }) + } +} + +func TestDoRequestWithRetry_Delay(t *testing.T) { + retryDelayUnit = time.Millisecond + t.Cleanup(func() { retryDelayUnit = time.Second }) + + var start time.Time + delays := []time.Duration{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(delays) == 0 { + delays = append(delays, 0) + w.WriteHeader(http.StatusInternalServerError) + return + } + if len(delays) == 1 { + start = time.Now() + delays = append(delays, 0) + w.WriteHeader(http.StatusInternalServerError) + return + } + if len(delays) == 2 { + elapsed := time.Since(start) + delays = append(delays, elapsed) + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + } + })) + defer server.Close() + + client := &http.Client{Timeout: 10 * time.Second} + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := DoRequestWithRetry(client, req) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + + assert.GreaterOrEqual(t, delays[2], time.Millisecond) +}