mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' of https://github.com/sipeed/picoclaw
This commit is contained in:
@@ -84,7 +84,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
@@ -97,7 +103,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
@@ -116,7 +128,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if cfg.APIKey == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey,
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
|
||||
@@ -6,7 +6,11 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -247,3 +251,42 @@ func TestCreateProviderFromConfig_EmptyModel(t *testing.T) {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for empty model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-timeout",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: server.URL,
|
||||
RequestTimeout: 1,
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if modelID != "gpt-4o" {
|
||||
t.Fatalf("modelID = %q, want %q", modelID, "gpt-4o")
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected timeout error, got nil")
|
||||
}
|
||||
errMsg := err.Error()
|
||||
if !strings.Contains(errMsg, "context deadline exceeded") && !strings.Contains(errMsg, "Client.Timeout exceeded") {
|
||||
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
|
||||
)
|
||||
@@ -23,8 +24,21 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField string,
|
||||
requestTimeoutSeconds int,
|
||||
) *HTTPProvider {
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField),
|
||||
delegate: openai_compat.NewProvider(
|
||||
apiKey,
|
||||
apiBase,
|
||||
proxy,
|
||||
openai_compat.WithMaxTokensField(maxTokensField),
|
||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,13 +34,27 @@ type Provider struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string) *Provider {
|
||||
return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
|
||||
type Option func(*Provider)
|
||||
|
||||
const defaultRequestTimeout = 120 * time.Second
|
||||
|
||||
func WithMaxTokensField(maxTokensField string) Option {
|
||||
return func(p *Provider) {
|
||||
p.maxTokensField = maxTokensField
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(p *Provider) {
|
||||
if timeout > 0 {
|
||||
p.httpClient.Timeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
@@ -54,12 +68,36 @@ func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string
|
||||
}
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
maxTokensField: maxTokensField,
|
||||
httpClient: client,
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(p)
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
|
||||
return NewProvider(apiKey, apiBase, proxy, WithMaxTokensField(maxTokensField))
|
||||
}
|
||||
|
||||
func NewProviderWithMaxTokensFieldAndTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField string,
|
||||
requestTimeoutSeconds int,
|
||||
) *Provider {
|
||||
return NewProvider(
|
||||
apiKey,
|
||||
apiBase,
|
||||
proxy,
|
||||
WithMaxTokensField(maxTokensField),
|
||||
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(
|
||||
@@ -115,8 +153,12 @@ func (p *Provider) Chat(
|
||||
// with the same key and reuse prefix KV cache across calls.
|
||||
// The key is typically the agent ID — stable per agent, shared across requests.
|
||||
// See: https://platform.openai.com/docs/guides/prompt-caching
|
||||
// Prompt caching is only supported by OpenAI-native endpoints.
|
||||
// Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs.
|
||||
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
|
||||
requestBody["prompt_cache_key"] = cacheKey
|
||||
if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") {
|
||||
requestBody["prompt_cache_key"] = cacheKey
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
@@ -325,3 +326,38 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutDefault(t *testing.T) {
|
||||
p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 0)
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutOverride(t *testing.T) {
|
||||
p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 300)
|
||||
if p.httpClient.Timeout != 300*time.Second {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
|
||||
p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
|
||||
if p.maxTokensField != "max_completion_tokens" {
|
||||
t.Fatalf("maxTokensField = %q, want %q", p.maxTokensField, "max_completion_tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_FunctionalOptionRequestTimeout(t *testing.T) {
|
||||
p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(45*time.Second))
|
||||
if p.httpClient.Timeout != 45*time.Second {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 45*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
|
||||
p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(-1*time.Second))
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user