mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(openai_compat): only send prompt_cache_key to OpenAI endpoints (#1353)
Non-OpenAI providers (Mistral, DeepSeek, Groq, etc.) reject unknown request fields with 422 errors. The previous blocklist only excluded Google/Gemini, but the comment already noted this feature is OpenAI-only. Flip to an allowlist so only api.openai.com receives the field. Fixes #1333
This commit is contained in:
@@ -156,9 +156,10 @@ func (p *Provider) Chat(
|
||||
// The key is typically the agent ID — stable per agent, shared across requests.
|
||||
// See: https://platform.openai.com/docs/guides/prompt-caching
|
||||
// Prompt caching is only supported by OpenAI-native endpoints.
|
||||
// Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs.
|
||||
// Non-OpenAI providers (Mistral, Gemini, DeepSeek, etc.) reject unknown
|
||||
// fields with 422 errors, so only include it for OpenAI APIs.
|
||||
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
|
||||
if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") {
|
||||
if supportsPromptCacheKey(p.apiBase) {
|
||||
requestBody["prompt_cache_key"] = cacheKey
|
||||
}
|
||||
}
|
||||
@@ -476,3 +477,11 @@ func asFloat(v any) (float64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API supports this. All other OpenAI-compatible providers (Mistral,
|
||||
// Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
|
||||
func supportsPromptCacheKey(apiBase string) bool {
|
||||
return strings.Contains(apiBase, "api.openai.com")
|
||||
}
|
||||
|
||||
@@ -669,6 +669,136 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Simulate an OpenAI endpoint by overriding the apiBase after creation.
|
||||
p := NewProvider("key", server.URL, "")
|
||||
p.apiBase = "https://api.openai.com/v1"
|
||||
// Point the HTTP client at our test server instead.
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Redirect all requests to the test server.
|
||||
r.URL, _ = url.Parse(server.URL + r.URL.Path)
|
||||
return http.DefaultTransport.RoundTrip(r)
|
||||
}),
|
||||
}
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
map[string]any{"prompt_cache_key": "agent-main"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if requestBody["prompt_cache_key"] != "agent-main" {
|
||||
t.Fatalf("prompt_cache_key = %v, want %q", requestBody["prompt_cache_key"], "agent-main")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
}{
|
||||
{"mistral", "https://api.mistral.ai/v1"},
|
||||
{"gemini", "https://generativelanguage.googleapis.com/v1beta"},
|
||||
{"deepseek", "https://api.deepseek.com/v1"},
|
||||
{"groq", "https://api.groq.com/openai/v1"},
|
||||
{"minimax", "https://api.minimaxi.com/v1"},
|
||||
{"ollama_local", "http://localhost:11434/v1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
p.apiBase = tt.apiBase
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
r.URL, _ = url.Parse(server.URL + r.URL.Path)
|
||||
return http.DefaultTransport.RoundTrip(r)
|
||||
}),
|
||||
}
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"test-model",
|
||||
map[string]any{"prompt_cache_key": "agent-main"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if _, exists := requestBody["prompt_cache_key"]; exists {
|
||||
t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsPromptCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
apiBase string
|
||||
want bool
|
||||
}{
|
||||
{"https://api.openai.com/v1", true},
|
||||
{"https://api.openai.com/v1/", true},
|
||||
{"https://api.mistral.ai/v1", false},
|
||||
{"https://generativelanguage.googleapis.com/v1beta", false},
|
||||
{"https://api.deepseek.com/v1", false},
|
||||
{"https://api.groq.com/openai/v1", false},
|
||||
{"http://localhost:11434/v1", false},
|
||||
{"https://openrouter.ai/api/v1", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := supportsPromptCacheKey(tt.apiBase); got != tt.want {
|
||||
t.Errorf("supportsPromptCacheKey(%q) = %v, want %v", tt.apiBase, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user