From 4a80c6f58c1b14d69e3c4d5da199c746eb59a881 Mon Sep 17 00:00:00 2001 From: Mahendra Teja <109858254+mahendrateja95@users.noreply.github.com> Date: Wed, 11 Mar 2026 22:51:54 +0530 Subject: [PATCH] fix(openai_compat): only send prompt_cache_key to OpenAI endpoints (#1353) Non-OpenAI providers (Mistral, DeepSeek, Groq, etc.) reject unknown request fields with 422 errors. The previous blocklist only excluded Google/Gemini, but the comment already noted this feature is OpenAI-only. Flip to an allowlist so only api.openai.com receives the field. Fixes #1333 --- pkg/providers/openai_compat/provider.go | 13 +- pkg/providers/openai_compat/provider_test.go | 130 +++++++++++++++++++ 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 0e8db7409..aa4fa9e6d 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -156,9 +156,10 @@ func (p *Provider) Chat( // The key is typically the agent ID — stable per agent, shared across requests. // See: https://platform.openai.com/docs/guides/prompt-caching // Prompt caching is only supported by OpenAI-native endpoints. - // Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs. + // Non-OpenAI providers (Mistral, Gemini, DeepSeek, etc.) reject unknown + // fields with 422 errors, so only include it for OpenAI APIs. if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" { - if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") { + if supportsPromptCacheKey(p.apiBase) { requestBody["prompt_cache_key"] = cacheKey } } @@ -476,3 +477,11 @@ func asFloat(v any) (float64, bool) { return 0, false } } + +// supportsPromptCacheKey reports whether the given API base is known to +// support the prompt_cache_key request field. Currently only OpenAI's own +// API supports this. All other OpenAI-compatible providers (Mistral, +// Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors. +func supportsPromptCacheKey(apiBase string) bool { + return strings.Contains(apiBase, "api.openai.com") +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 9a3a7acc5..5581146fe 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -669,6 +669,136 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { } } +func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + // Simulate an OpenAI endpoint by overriding the apiBase after creation. + p := NewProvider("key", server.URL, "") + p.apiBase = "https://api.openai.com/v1" + // Point the HTTP client at our test server instead. + p.httpClient = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + // Redirect all requests to the test server. + r.URL, _ = url.Parse(server.URL + r.URL.Path) + return http.DefaultTransport.RoundTrip(r) + }), + } + + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "gpt-4o", + map[string]any{"prompt_cache_key": "agent-main"}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if requestBody["prompt_cache_key"] != "agent-main" { + t.Fatalf("prompt_cache_key = %v, want %q", requestBody["prompt_cache_key"], "agent-main") + } +} + +func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) { + tests := []struct { + name string + apiBase string + }{ + {"mistral", "https://api.mistral.ai/v1"}, + {"gemini", "https://generativelanguage.googleapis.com/v1beta"}, + {"deepseek", "https://api.deepseek.com/v1"}, + {"groq", "https://api.groq.com/openai/v1"}, + {"minimax", "https://api.minimaxi.com/v1"}, + {"ollama_local", "http://localhost:11434/v1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + p.apiBase = tt.apiBase + p.httpClient = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + r.URL, _ = url.Parse(server.URL + r.URL.Path) + return http.DefaultTransport.RoundTrip(r) + }), + } + + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "test-model", + map[string]any{"prompt_cache_key": "agent-main"}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if _, exists := requestBody["prompt_cache_key"]; exists { + t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name) + } + }) + } +} + +func TestSupportsPromptCacheKey(t *testing.T) { + tests := []struct { + apiBase string + want bool + }{ + {"https://api.openai.com/v1", true}, + {"https://api.openai.com/v1/", true}, + {"https://api.mistral.ai/v1", false}, + {"https://generativelanguage.googleapis.com/v1beta", false}, + {"https://api.deepseek.com/v1", false}, + {"https://api.groq.com/openai/v1", false}, + {"http://localhost:11434/v1", false}, + {"https://openrouter.ai/api/v1", false}, + } + for _, tt := range tests { + if got := supportsPromptCacheKey(tt.apiBase); got != tt.want { + t.Errorf("supportsPromptCacheKey(%q) = %v, want %v", tt.apiBase, got, tt.want) + } + } +} + func TestSerializeMessages_StripsSystemParts(t *testing.T) { messages := []protocoltypes.Message{ {