From 6612ca099a2833c449a6c363fae11146a849f15d Mon Sep 17 00:00:00 2001 From: Mahendra Teja <109858254+mahendrateja95@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:54:31 +0530 Subject: [PATCH] fix(openai_compat): improve prompt_cache_key host matching (#1387) LGTM! The changes improve the robustness of prompt_cache_key host matching and add Azure OpenAI support. Thanks for the contribution! --- pkg/providers/openai_compat/provider.go | 11 ++- pkg/providers/openai_compat/provider_test.go | 71 +++++++------------- 2 files changed, 31 insertions(+), 51 deletions(-) diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index d5f4bdfce..f97bf3acd 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -508,8 +508,13 @@ func asFloat(v any) (float64, bool) { // supportsPromptCacheKey reports whether the given API base is known to // support the prompt_cache_key request field. Currently only OpenAI's own -// API supports this. All other OpenAI-compatible providers (Mistral, -// Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors. +// API and Azure OpenAI support this. All other OpenAI-compatible providers +// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors. func supportsPromptCacheKey(apiBase string) bool { - return strings.Contains(apiBase, "api.openai.com") + u, err := url.Parse(apiBase) + if err != nil { + return false + } + host := u.Hostname() + return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") } diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 39aff1d1a..41f278a1b 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -718,7 +718,10 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { } } -func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) { +// chatWithCacheKey sets up a test server, sends a Chat request with prompt_cache_key, +// and returns the decoded request body for assertion. +func chatWithCacheKey(t *testing.T, apiBase string) map[string]any { + t.Helper() var requestBody map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -739,13 +742,10 @@ func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) { })) defer server.Close() - // Simulate an OpenAI endpoint by overriding the apiBase after creation. p := NewProvider("key", server.URL, "") - p.apiBase = "https://api.openai.com/v1" - // Point the HTTP client at our test server instead. + p.apiBase = apiBase p.httpClient = &http.Client{ Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { - // Redirect all requests to the test server. r.URL, _ = url.Parse(server.URL + r.URL.Path) return http.DefaultTransport.RoundTrip(r) }), @@ -755,14 +755,19 @@ func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) { t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, - "gpt-4o", + "test-model", map[string]any{"prompt_cache_key": "agent-main"}, ) if err != nil { t.Fatalf("Chat() error = %v", err) } - if requestBody["prompt_cache_key"] != "agent-main" { - t.Fatalf("prompt_cache_key = %v, want %q", requestBody["prompt_cache_key"], "agent-main") + return requestBody +} + +func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) { + body := chatWithCacheKey(t, "https://api.openai.com/v1") + if body["prompt_cache_key"] != "agent-main" { + t.Fatalf("prompt_cache_key = %v, want %q", body["prompt_cache_key"], "agent-main") } } @@ -781,46 +786,8 @@ func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var requestBody map[string]any - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - resp := map[string]any{ - "choices": []map[string]any{ - { - "message": map[string]any{"content": "ok"}, - "finish_reason": "stop", - }, - }, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) - })) - defer server.Close() - - p := NewProvider("key", server.URL, "") - p.apiBase = tt.apiBase - p.httpClient = &http.Client{ - Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { - r.URL, _ = url.Parse(server.URL + r.URL.Path) - return http.DefaultTransport.RoundTrip(r) - }), - } - - _, err := p.Chat( - t.Context(), - []Message{{Role: "user", Content: "hi"}}, - nil, - "test-model", - map[string]any{"prompt_cache_key": "agent-main"}, - ) - if err != nil { - t.Fatalf("Chat() error = %v", err) - } - if _, exists := requestBody["prompt_cache_key"]; exists { + body := chatWithCacheKey(t, tt.apiBase) + if _, exists := body["prompt_cache_key"]; exists { t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name) } }) @@ -834,12 +801,20 @@ func TestSupportsPromptCacheKey(t *testing.T) { }{ {"https://api.openai.com/v1", true}, {"https://api.openai.com/v1/", true}, + {"https://myresource.openai.azure.com/openai/deployments/gpt-4", true}, + {"https://eastus.openai.azure.com/v1", true}, {"https://api.mistral.ai/v1", false}, {"https://generativelanguage.googleapis.com/v1beta", false}, {"https://api.deepseek.com/v1", false}, {"https://api.groq.com/openai/v1", false}, {"http://localhost:11434/v1", false}, {"https://openrouter.ai/api/v1", false}, + // Edge cases: proxy URLs with openai.com in path should NOT match + {"https://my-proxy.com/api.openai.com/v1", false}, + {"https://proxy.example.com/openai.azure.com/v1", false}, + // Malformed or empty + {"", false}, + {"not-a-url", false}, } for _, tt := range tests { if got := supportsPromptCacheKey(tt.apiBase); got != tt.want {