fix(openai_compat): only send prompt_cache_key to OpenAI endpoints (#1353)

Non-OpenAI providers (Mistral, DeepSeek, Groq, etc.) reject unknown
request fields with 422 errors. The previous blocklist only excluded
Google/Gemini, but the comment already noted this feature is
OpenAI-only. Flip to an allowlist so only api.openai.com receives
the field.

Fixes #1333
This commit is contained in:
Mahendra Teja
2026-03-11 22:51:54 +05:30
committed by GitHub
parent 9b0a48ac6d
commit 4a80c6f58c
2 changed files with 141 additions and 2 deletions
+11 -2
View File
@@ -156,9 +156,10 @@ func (p *Provider) Chat(
// The key is typically the agent ID — stable per agent, shared across requests.
// See: https://platform.openai.com/docs/guides/prompt-caching
// Prompt caching is only supported by OpenAI-native endpoints.
// Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs.
// Non-OpenAI providers (Mistral, Gemini, DeepSeek, etc.) reject unknown
// fields with 422 errors, so only include it for OpenAI APIs.
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") {
if supportsPromptCacheKey(p.apiBase) {
requestBody["prompt_cache_key"] = cacheKey
}
}
@@ -476,3 +477,11 @@ func asFloat(v any) (float64, bool) {
return 0, false
}
}
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API supports this. All other OpenAI-compatible providers (Mistral,
// Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
//
// The host component of apiBase is compared against api.openai.com exactly
// (case-insensitively) rather than searching the whole URL. A substring
// search would also accept look-alike hosts such as
// "api.openai.com.example.net" or "notapi.openai.com", and proxies whose path
// merely mentions api.openai.com. A base without a scheme (for example
// "api.openai.com/v1") is accepted as well.
func supportsPromptCacheKey(apiBase string) bool {
	const openAIHost = "api.openai.com"

	host := strings.TrimSpace(apiBase)

	// Drop the scheme, but only when "://" appears before any path, query or
	// fragment delimiter; otherwise it belongs to a later URL component.
	if i := strings.Index(host, "://"); i >= 0 && !strings.ContainsAny(host[:i], "/?#") {
		host = host[i+len("://"):]
	}
	// Drop everything after the authority: path, query and fragment.
	if i := strings.IndexAny(host, "/?#"); i >= 0 {
		host = host[:i]
	}
	// Drop userinfo ("user:pass@").
	if i := strings.LastIndexByte(host, '@'); i >= 0 {
		host = host[i+1:]
	}
	// Drop an explicit port (":443").
	if i := strings.LastIndexByte(host, ':'); i >= 0 {
		host = host[:i]
	}
	// A fully-qualified DNS name may carry a trailing dot.
	host = strings.TrimSuffix(host, ".")

	return strings.EqualFold(host, openAIHost)
}
@@ -669,6 +669,136 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
}
}
// TestProviderChat_PromptCacheKeySentToOpenAI verifies that Chat includes the
// prompt_cache_key option in the request body when the provider's apiBase
// points at api.openai.com. The request is transparently rerouted to a local
// httptest server so no real network call is made.
func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
	var requestBody map[string]any
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Capture the decoded request body for the assertions below.
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	target, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("url.Parse(%q) error = %v", server.URL, err)
	}

	// Simulate an OpenAI endpoint by overriding the apiBase after creation.
	p := NewProvider("key", server.URL, "")
	p.apiBase = "https://api.openai.com/v1"
	// Point the HTTP client at our test server instead.
	p.httpClient = &http.Client{
		Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
			// A RoundTripper must not modify the caller's request, so
			// reroute a clone to the test server instead.
			redirected := r.Clone(r.Context())
			redirected.URL.Scheme = target.Scheme
			redirected.URL.Host = target.Host
			return http.DefaultTransport.RoundTrip(redirected)
		}),
	}

	_, err = p.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hi"}},
		nil,
		"gpt-4o",
		map[string]any{"prompt_cache_key": "agent-main"},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if requestBody["prompt_cache_key"] != "agent-main" {
		t.Fatalf("prompt_cache_key = %v, want %q", requestBody["prompt_cache_key"], "agent-main")
	}
}
// TestProviderChat_PromptCacheKeyOmittedForNonOpenAI verifies that Chat does
// not put prompt_cache_key in the request body when the provider's apiBase
// points at something other than api.openai.com. Each case reroutes the
// request to a local httptest server and inspects the body it received.
func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) {
	tests := []struct {
		name    string
		apiBase string
	}{
		{"mistral", "https://api.mistral.ai/v1"},
		{"gemini", "https://generativelanguage.googleapis.com/v1beta"},
		{"deepseek", "https://api.deepseek.com/v1"},
		{"groq", "https://api.groq.com/openai/v1"},
		{"minimax", "https://api.minimaxi.com/v1"},
		{"ollama_local", "http://localhost:11434/v1"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var requestBody map[string]any
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Capture the decoded request body for the assertions below.
				if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
					http.Error(w, err.Error(), http.StatusBadRequest)
					return
				}
				resp := map[string]any{
					"choices": []map[string]any{
						{
							"message":       map[string]any{"content": "ok"},
							"finish_reason": "stop",
						},
					},
				}
				w.Header().Set("Content-Type", "application/json")
				if err := json.NewEncoder(w).Encode(resp); err != nil {
					t.Errorf("encode response: %v", err)
				}
			}))
			defer server.Close()

			target, err := url.Parse(server.URL)
			if err != nil {
				t.Fatalf("url.Parse(%q) error = %v", server.URL, err)
			}

			p := NewProvider("key", server.URL, "")
			p.apiBase = tt.apiBase
			p.httpClient = &http.Client{
				Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
					// A RoundTripper must not modify the caller's request,
					// so reroute a clone to the test server instead.
					redirected := r.Clone(r.Context())
					redirected.URL.Scheme = target.Scheme
					redirected.URL.Host = target.Host
					return http.DefaultTransport.RoundTrip(redirected)
				}),
			}

			_, err = p.Chat(
				t.Context(),
				[]Message{{Role: "user", Content: "hi"}},
				nil,
				"test-model",
				map[string]any{"prompt_cache_key": "agent-main"},
			)
			if err != nil {
				t.Fatalf("Chat() error = %v", err)
			}
			// Without a captured body the absence check below would pass
			// vacuously, so require that the server actually saw a request.
			if requestBody == nil {
				t.Fatalf("test server did not capture a request body for %s", tt.name)
			}
			if _, exists := requestBody["prompt_cache_key"]; exists {
				t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name)
			}
		})
	}
}
// TestSupportsPromptCacheKey runs supportsPromptCacheKey over a table of API
// base URLs. The two api.openai.com entries are expected to be accepted and
// every other entry rejected.
func TestSupportsPromptCacheKey(t *testing.T) {
	type testCase struct {
		base     string
		expected bool
	}
	cases := []testCase{
		{base: "https://api.openai.com/v1", expected: true},
		{base: "https://api.openai.com/v1/", expected: true},
		{base: "https://api.mistral.ai/v1", expected: false},
		{base: "https://generativelanguage.googleapis.com/v1beta", expected: false},
		{base: "https://api.deepseek.com/v1", expected: false},
		{base: "https://api.groq.com/openai/v1", expected: false},
		{base: "http://localhost:11434/v1", expected: false},
		{base: "https://openrouter.ai/api/v1", expected: false},
	}
	for i := range cases {
		tc := cases[i]
		actual := supportsPromptCacheKey(tc.base)
		if actual == tc.expected {
			continue
		}
		// Report and keep going so every mismatching entry is listed.
		t.Errorf("supportsPromptCacheKey(%q) = %v, want %v", tc.base, actual, tc.expected)
	}
}
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
messages := []protocoltypes.Message{
{