mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(providers): support per-model request_timeout in model_list (#733)
* fix(providers): support per-model request_timeout in model_list * fix(lint): format provider constructors for golines * refactor(providers): adopt functional options and preserve timeout migration * docs(readme): sync request_timeout guidance across translated docs --------- Co-authored-by: Yiliu <yiliu@affiliate-guide.com>
This commit is contained in:
@@ -371,11 +371,12 @@ func (p ProvidersConfig) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
RequestTimeout int `json:"request_timeout,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_REQUEST_TIMEOUT"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc`
|
||||
}
|
||||
|
||||
type OpenAIProviderConfig struct {
|
||||
@@ -406,6 +407,7 @@ type ModelConfig struct {
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
}
|
||||
|
||||
// Validate checks if the ModelConfig has all required fields.
|
||||
|
||||
+98
-82
@@ -60,12 +60,13 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "openai",
|
||||
Model: "openai/gpt-5.2",
|
||||
APIKey: p.OpenAI.APIKey,
|
||||
APIBase: p.OpenAI.APIBase,
|
||||
Proxy: p.OpenAI.Proxy,
|
||||
AuthMethod: p.OpenAI.AuthMethod,
|
||||
ModelName: "openai",
|
||||
Model: "openai/gpt-5.2",
|
||||
APIKey: p.OpenAI.APIKey,
|
||||
APIBase: p.OpenAI.APIBase,
|
||||
Proxy: p.OpenAI.Proxy,
|
||||
RequestTimeout: p.OpenAI.RequestTimeout,
|
||||
AuthMethod: p.OpenAI.AuthMethod,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -77,12 +78,13 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "anthropic",
|
||||
Model: "anthropic/claude-sonnet-4.6",
|
||||
APIKey: p.Anthropic.APIKey,
|
||||
APIBase: p.Anthropic.APIBase,
|
||||
Proxy: p.Anthropic.Proxy,
|
||||
AuthMethod: p.Anthropic.AuthMethod,
|
||||
ModelName: "anthropic",
|
||||
Model: "anthropic/claude-sonnet-4.6",
|
||||
APIKey: p.Anthropic.APIKey,
|
||||
APIBase: p.Anthropic.APIBase,
|
||||
Proxy: p.Anthropic.Proxy,
|
||||
RequestTimeout: p.Anthropic.RequestTimeout,
|
||||
AuthMethod: p.Anthropic.AuthMethod,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -94,11 +96,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIKey: p.OpenRouter.APIKey,
|
||||
APIBase: p.OpenRouter.APIBase,
|
||||
Proxy: p.OpenRouter.Proxy,
|
||||
ModelName: "openrouter",
|
||||
Model: "openrouter/auto",
|
||||
APIKey: p.OpenRouter.APIKey,
|
||||
APIBase: p.OpenRouter.APIBase,
|
||||
Proxy: p.OpenRouter.Proxy,
|
||||
RequestTimeout: p.OpenRouter.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -110,11 +113,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.1-70b-versatile",
|
||||
APIKey: p.Groq.APIKey,
|
||||
APIBase: p.Groq.APIBase,
|
||||
Proxy: p.Groq.Proxy,
|
||||
ModelName: "groq",
|
||||
Model: "groq/llama-3.1-70b-versatile",
|
||||
APIKey: p.Groq.APIKey,
|
||||
APIBase: p.Groq.APIBase,
|
||||
Proxy: p.Groq.Proxy,
|
||||
RequestTimeout: p.Groq.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -126,11 +130,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "zhipu",
|
||||
Model: "zhipu/glm-4",
|
||||
APIKey: p.Zhipu.APIKey,
|
||||
APIBase: p.Zhipu.APIBase,
|
||||
Proxy: p.Zhipu.Proxy,
|
||||
ModelName: "zhipu",
|
||||
Model: "zhipu/glm-4",
|
||||
APIKey: p.Zhipu.APIKey,
|
||||
APIBase: p.Zhipu.APIBase,
|
||||
Proxy: p.Zhipu.Proxy,
|
||||
RequestTimeout: p.Zhipu.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -142,11 +147,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "vllm",
|
||||
Model: "vllm/auto",
|
||||
APIKey: p.VLLM.APIKey,
|
||||
APIBase: p.VLLM.APIBase,
|
||||
Proxy: p.VLLM.Proxy,
|
||||
ModelName: "vllm",
|
||||
Model: "vllm/auto",
|
||||
APIKey: p.VLLM.APIKey,
|
||||
APIBase: p.VLLM.APIBase,
|
||||
Proxy: p.VLLM.Proxy,
|
||||
RequestTimeout: p.VLLM.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -158,11 +164,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "gemini",
|
||||
Model: "gemini/gemini-pro",
|
||||
APIKey: p.Gemini.APIKey,
|
||||
APIBase: p.Gemini.APIBase,
|
||||
Proxy: p.Gemini.Proxy,
|
||||
ModelName: "gemini",
|
||||
Model: "gemini/gemini-pro",
|
||||
APIKey: p.Gemini.APIKey,
|
||||
APIBase: p.Gemini.APIBase,
|
||||
Proxy: p.Gemini.Proxy,
|
||||
RequestTimeout: p.Gemini.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -174,11 +181,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "nvidia",
|
||||
Model: "nvidia/meta/llama-3.1-8b-instruct",
|
||||
APIKey: p.Nvidia.APIKey,
|
||||
APIBase: p.Nvidia.APIBase,
|
||||
Proxy: p.Nvidia.Proxy,
|
||||
ModelName: "nvidia",
|
||||
Model: "nvidia/meta/llama-3.1-8b-instruct",
|
||||
APIKey: p.Nvidia.APIKey,
|
||||
APIBase: p.Nvidia.APIBase,
|
||||
Proxy: p.Nvidia.Proxy,
|
||||
RequestTimeout: p.Nvidia.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -190,11 +198,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "ollama",
|
||||
Model: "ollama/llama3",
|
||||
APIKey: p.Ollama.APIKey,
|
||||
APIBase: p.Ollama.APIBase,
|
||||
Proxy: p.Ollama.Proxy,
|
||||
ModelName: "ollama",
|
||||
Model: "ollama/llama3",
|
||||
APIKey: p.Ollama.APIKey,
|
||||
APIBase: p.Ollama.APIBase,
|
||||
Proxy: p.Ollama.Proxy,
|
||||
RequestTimeout: p.Ollama.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -206,11 +215,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "moonshot",
|
||||
Model: "moonshot/kimi",
|
||||
APIKey: p.Moonshot.APIKey,
|
||||
APIBase: p.Moonshot.APIBase,
|
||||
Proxy: p.Moonshot.Proxy,
|
||||
ModelName: "moonshot",
|
||||
Model: "moonshot/kimi",
|
||||
APIKey: p.Moonshot.APIKey,
|
||||
APIBase: p.Moonshot.APIBase,
|
||||
Proxy: p.Moonshot.Proxy,
|
||||
RequestTimeout: p.Moonshot.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -222,11 +232,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "shengsuanyun",
|
||||
Model: "shengsuanyun/auto",
|
||||
APIKey: p.ShengSuanYun.APIKey,
|
||||
APIBase: p.ShengSuanYun.APIBase,
|
||||
Proxy: p.ShengSuanYun.Proxy,
|
||||
ModelName: "shengsuanyun",
|
||||
Model: "shengsuanyun/auto",
|
||||
APIKey: p.ShengSuanYun.APIKey,
|
||||
APIBase: p.ShengSuanYun.APIBase,
|
||||
Proxy: p.ShengSuanYun.Proxy,
|
||||
RequestTimeout: p.ShengSuanYun.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -238,11 +249,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "deepseek",
|
||||
Model: "deepseek/deepseek-chat",
|
||||
APIKey: p.DeepSeek.APIKey,
|
||||
APIBase: p.DeepSeek.APIBase,
|
||||
Proxy: p.DeepSeek.Proxy,
|
||||
ModelName: "deepseek",
|
||||
Model: "deepseek/deepseek-chat",
|
||||
APIKey: p.DeepSeek.APIKey,
|
||||
APIBase: p.DeepSeek.APIBase,
|
||||
Proxy: p.DeepSeek.Proxy,
|
||||
RequestTimeout: p.DeepSeek.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -254,11 +266,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "cerebras",
|
||||
Model: "cerebras/llama-3.3-70b",
|
||||
APIKey: p.Cerebras.APIKey,
|
||||
APIBase: p.Cerebras.APIBase,
|
||||
Proxy: p.Cerebras.Proxy,
|
||||
ModelName: "cerebras",
|
||||
Model: "cerebras/llama-3.3-70b",
|
||||
APIKey: p.Cerebras.APIKey,
|
||||
APIBase: p.Cerebras.APIBase,
|
||||
Proxy: p.Cerebras.Proxy,
|
||||
RequestTimeout: p.Cerebras.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -270,11 +283,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "volcengine",
|
||||
Model: "volcengine/doubao-pro",
|
||||
APIKey: p.VolcEngine.APIKey,
|
||||
APIBase: p.VolcEngine.APIBase,
|
||||
Proxy: p.VolcEngine.Proxy,
|
||||
ModelName: "volcengine",
|
||||
Model: "volcengine/doubao-pro",
|
||||
APIKey: p.VolcEngine.APIKey,
|
||||
APIBase: p.VolcEngine.APIBase,
|
||||
Proxy: p.VolcEngine.Proxy,
|
||||
RequestTimeout: p.VolcEngine.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -316,11 +330,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "qwen",
|
||||
Model: "qwen/qwen-max",
|
||||
APIKey: p.Qwen.APIKey,
|
||||
APIBase: p.Qwen.APIBase,
|
||||
Proxy: p.Qwen.Proxy,
|
||||
ModelName: "qwen",
|
||||
Model: "qwen/qwen-max",
|
||||
APIKey: p.Qwen.APIKey,
|
||||
APIBase: p.Qwen.APIBase,
|
||||
Proxy: p.Qwen.Proxy,
|
||||
RequestTimeout: p.Qwen.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
@@ -332,11 +347,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "mistral",
|
||||
Model: "mistral/mistral-small-latest",
|
||||
APIKey: p.Mistral.APIKey,
|
||||
APIBase: p.Mistral.APIBase,
|
||||
Proxy: p.Mistral.Proxy,
|
||||
ModelName: "mistral",
|
||||
Model: "mistral/mistral-small-latest",
|
||||
APIKey: p.Mistral.APIKey,
|
||||
APIBase: p.Mistral.APIBase,
|
||||
Proxy: p.Mistral.Proxy,
|
||||
RequestTimeout: p.Mistral.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
|
||||
@@ -166,6 +166,27 @@ func TestConvertProvidersToModelList_Proxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_RequestTimeout(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
Ollama: ProviderConfig{
|
||||
APIKey: "ollama-key",
|
||||
RequestTimeout: 300,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].RequestTimeout != 300 {
|
||||
t.Errorf("RequestTimeout = %d, want %d", result[0].RequestTimeout, 300)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_AuthMethod(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
|
||||
@@ -365,3 +365,38 @@ func TestConfig_ValidateModelList(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_RequestTimeoutParsing(t *testing.T) {
|
||||
jsonData := `{
|
||||
"model_name": "slow-local",
|
||||
"model": "openai/local-model",
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"request_timeout": 300
|
||||
}`
|
||||
|
||||
var cfg ModelConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.RequestTimeout != 300 {
|
||||
t.Fatalf("RequestTimeout = %d, want 300", cfg.RequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_RequestTimeoutDefaultZeroValue(t *testing.T) {
|
||||
jsonData := `{
|
||||
"model_name": "default-timeout",
|
||||
"model": "openai/gpt-4o",
|
||||
"api_key": "test-key"
|
||||
}`
|
||||
|
||||
var cfg ModelConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.RequestTimeout != 0 {
|
||||
t.Fatalf("RequestTimeout = %d, want 0", cfg.RequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +84,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
@@ -97,7 +103,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
@@ -116,7 +128,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if cfg.APIKey == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
|
||||
@@ -6,7 +6,11 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -247,3 +251,42 @@ func TestCreateProviderFromConfig_EmptyModel(t *testing.T) {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for empty model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-timeout",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: server.URL,
|
||||
RequestTimeout: 1,
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if modelID != "gpt-4o" {
|
||||
t.Fatalf("modelID = %q, want %q", modelID, "gpt-4o")
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected timeout error, got nil")
|
||||
}
|
||||
errMsg := err.Error()
|
||||
if !strings.Contains(errMsg, "context deadline exceeded") && !strings.Contains(errMsg, "Client.Timeout exceeded") {
|
||||
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
|
||||
)
|
||||
@@ -23,8 +24,21 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField string,
|
||||
requestTimeoutSeconds int,
|
||||
) *HTTPProvider {
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField),
|
||||
delegate: openai_compat.NewProvider(
|
||||
apiKey,
|
||||
apiBase,
|
||||
proxy,
|
||||
openai_compat.WithMaxTokensField(maxTokensField),
|
||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,13 +34,27 @@ type Provider struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string) *Provider {
|
||||
return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
|
||||
type Option func(*Provider)
|
||||
|
||||
const defaultRequestTimeout = 120 * time.Second
|
||||
|
||||
func WithMaxTokensField(maxTokensField string) Option {
|
||||
return func(p *Provider) {
|
||||
p.maxTokensField = maxTokensField
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(p *Provider) {
|
||||
if timeout > 0 {
|
||||
p.httpClient.Timeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
@@ -54,12 +68,36 @@ func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string
|
||||
}
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
maxTokensField: maxTokensField,
|
||||
httpClient: client,
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(p)
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
|
||||
return NewProvider(apiKey, apiBase, proxy, WithMaxTokensField(maxTokensField))
|
||||
}
|
||||
|
||||
func NewProviderWithMaxTokensFieldAndTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField string,
|
||||
requestTimeoutSeconds int,
|
||||
) *Provider {
|
||||
return NewProvider(
|
||||
apiKey,
|
||||
apiBase,
|
||||
proxy,
|
||||
WithMaxTokensField(maxTokensField),
|
||||
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
@@ -325,3 +326,38 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutDefault(t *testing.T) {
|
||||
p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 0)
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutOverride(t *testing.T) {
|
||||
p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 300)
|
||||
if p.httpClient.Timeout != 300*time.Second {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
|
||||
p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
|
||||
if p.maxTokensField != "max_completion_tokens" {
|
||||
t.Fatalf("maxTokensField = %q, want %q", p.maxTokensField, "max_completion_tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_FunctionalOptionRequestTimeout(t *testing.T) {
|
||||
p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(45*time.Second))
|
||||
if p.httpClient.Timeout != 45*time.Second {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 45*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
|
||||
p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(-1*time.Second))
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user