This commit is contained in:
penglp
2026-02-26 19:29:32 +08:00
27 changed files with 648 additions and 159 deletions
+21 -3
View File
@@ -84,7 +84,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
cfg.APIKey,
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
), modelID, nil
case "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
@@ -97,7 +103,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
cfg.APIKey,
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
), modelID, nil
case "anthropic":
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
@@ -116,7 +128,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if cfg.APIKey == "" {
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
}
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
cfg.APIKey,
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
), modelID, nil
case "antigravity":
return NewAntigravityProvider(), modelID, nil
+43
View File
@@ -6,7 +6,11 @@
package providers
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -247,3 +251,42 @@ func TestCreateProviderFromConfig_EmptyModel(t *testing.T) {
t.Fatal("CreateProviderFromConfig() expected error for empty model")
}
}
func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1500 * time.Millisecond)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
}))
defer server.Close()
cfg := &config.ModelConfig{
ModelName: "test-timeout",
Model: "openai/gpt-4o",
APIBase: server.URL,
RequestTimeout: 1,
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if modelID != "gpt-4o" {
t.Fatalf("modelID = %q, want %q", modelID, "gpt-4o")
}
_, err = provider.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
modelID,
nil,
)
if err == nil {
t.Fatal("Chat() expected timeout error, got nil")
}
errMsg := err.Error()
if !strings.Contains(errMsg, "context deadline exceeded") && !strings.Contains(errMsg, "Client.Timeout exceeded") {
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
}
}
+15 -1
View File
@@ -8,6 +8,7 @@ package providers
import (
"context"
"time"
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
)
@@ -23,8 +24,21 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
}
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
}
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
apiKey, apiBase, proxy, maxTokensField string,
requestTimeoutSeconds int,
) *HTTPProvider {
return &HTTPProvider{
delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField),
delegate: openai_compat.NewProvider(
apiKey,
apiBase,
proxy,
openai_compat.WithMaxTokensField(maxTokensField),
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
),
}
}
+52 -10
View File
@@ -34,13 +34,27 @@ type Provider struct {
httpClient *http.Client
}
func NewProvider(apiKey, apiBase, proxy string) *Provider {
return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
type Option func(*Provider)
const defaultRequestTimeout = 120 * time.Second
func WithMaxTokensField(maxTokensField string) Option {
return func(p *Provider) {
p.maxTokensField = maxTokensField
}
}
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
func WithRequestTimeout(timeout time.Duration) Option {
return func(p *Provider) {
if timeout > 0 {
p.httpClient.Timeout = timeout
}
}
}
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
client := &http.Client{
Timeout: 120 * time.Second,
Timeout: defaultRequestTimeout,
}
if proxy != "" {
@@ -54,12 +68,36 @@ func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string
}
}
return &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
maxTokensField: maxTokensField,
httpClient: client,
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
}
for _, opt := range opts {
if opt != nil {
opt(p)
}
}
return p
}
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
return NewProvider(apiKey, apiBase, proxy, WithMaxTokensField(maxTokensField))
}
func NewProviderWithMaxTokensFieldAndTimeout(
apiKey, apiBase, proxy, maxTokensField string,
requestTimeoutSeconds int,
) *Provider {
return NewProvider(
apiKey,
apiBase,
proxy,
WithMaxTokensField(maxTokensField),
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
)
}
func (p *Provider) Chat(
@@ -115,8 +153,12 @@ func (p *Provider) Chat(
// with the same key and reuse prefix KV cache across calls.
// The key is typically the agent ID — stable per agent, shared across requests.
// See: https://platform.openai.com/docs/guides/prompt-caching
// Prompt caching is only supported by OpenAI-native endpoints.
// Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs.
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
requestBody["prompt_cache_key"] = cacheKey
if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") {
requestBody["prompt_cache_key"] = cacheKey
}
}
jsonData, err := json.Marshal(requestBody)
@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"net/url"
"testing"
"time"
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
@@ -325,3 +326,38 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
}
func TestProvider_RequestTimeoutDefault(t *testing.T) {
p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 0)
if p.httpClient.Timeout != defaultRequestTimeout {
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
func TestProvider_RequestTimeoutOverride(t *testing.T) {
p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 300)
if p.httpClient.Timeout != 300*time.Second {
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
}
}
func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
if p.maxTokensField != "max_completion_tokens" {
t.Fatalf("maxTokensField = %q, want %q", p.maxTokensField, "max_completion_tokens")
}
}
func TestProvider_FunctionalOptionRequestTimeout(t *testing.T) {
p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(45*time.Second))
if p.httpClient.Timeout != 45*time.Second {
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 45*time.Second)
}
}
func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(-1*time.Second))
if p.httpClient.Timeout != defaultRequestTimeout {
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}