fix(openai_compat): improve prompt_cache_key host matching (#1387)

LGTM! The changes improve the robustness of prompt_cache_key host matching and add Azure OpenAI support. Thanks for the contribution!
This commit is contained in:
Mahendra Teja
2026-03-12 00:54:31 +05:30
committed by GitHub
parent 49204df678
commit 22735aaee4
2 changed files with 31 additions and 51 deletions
+8 -3
View File
@@ -508,8 +508,13 @@ func asFloat(v any) (float64, bool) {
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API supports this. All other OpenAI-compatible providers (Mistral,
// Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
// API and Azure OpenAI support this. All other OpenAI-compatible providers
// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
func supportsPromptCacheKey(apiBase string) bool {
return strings.Contains(apiBase, "api.openai.com")
u, err := url.Parse(apiBase)
if err != nil {
return false
}
host := u.Hostname()
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
}
+23 -48
View File
@@ -718,7 +718,10 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
}
}
func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
// chatWithCacheKey sets up a test server, sends a Chat request with prompt_cache_key,
// and returns the decoded request body for assertion.
func chatWithCacheKey(t *testing.T, apiBase string) map[string]any {
t.Helper()
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -739,13 +742,10 @@ func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
}))
defer server.Close()
// Simulate an OpenAI endpoint by overriding the apiBase after creation.
p := NewProvider("key", server.URL, "")
p.apiBase = "https://api.openai.com/v1"
// Point the HTTP client at our test server instead.
p.apiBase = apiBase
p.httpClient = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
// Redirect all requests to the test server.
r.URL, _ = url.Parse(server.URL + r.URL.Path)
return http.DefaultTransport.RoundTrip(r)
}),
@@ -755,14 +755,19 @@ func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
"test-model",
map[string]any{"prompt_cache_key": "agent-main"},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["prompt_cache_key"] != "agent-main" {
t.Fatalf("prompt_cache_key = %v, want %q", requestBody["prompt_cache_key"], "agent-main")
return requestBody
}
func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
body := chatWithCacheKey(t, "https://api.openai.com/v1")
if body["prompt_cache_key"] != "agent-main" {
t.Fatalf("prompt_cache_key = %v, want %q", body["prompt_cache_key"], "agent-main")
}
}
@@ -781,46 +786,8 @@ func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
p.apiBase = tt.apiBase
p.httpClient = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
r.URL, _ = url.Parse(server.URL + r.URL.Path)
return http.DefaultTransport.RoundTrip(r)
}),
}
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"test-model",
map[string]any{"prompt_cache_key": "agent-main"},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, exists := requestBody["prompt_cache_key"]; exists {
body := chatWithCacheKey(t, tt.apiBase)
if _, exists := body["prompt_cache_key"]; exists {
t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name)
}
})
@@ -834,12 +801,20 @@ func TestSupportsPromptCacheKey(t *testing.T) {
}{
{"https://api.openai.com/v1", true},
{"https://api.openai.com/v1/", true},
{"https://myresource.openai.azure.com/openai/deployments/gpt-4", true},
{"https://eastus.openai.azure.com/v1", true},
{"https://api.mistral.ai/v1", false},
{"https://generativelanguage.googleapis.com/v1beta", false},
{"https://api.deepseek.com/v1", false},
{"https://api.groq.com/openai/v1", false},
{"http://localhost:11434/v1", false},
{"https://openrouter.ai/api/v1", false},
// Edge cases: proxy URLs with openai.com in path should NOT match
{"https://my-proxy.com/api.openai.com/v1", false},
{"https://proxy.example.com/openai.azure.com/v1", false},
// Malformed or empty
{"", false},
{"not-a-url", false},
}
for _, tt := range tests {
if got := supportsPromptCacheKey(tt.apiBase); got != tt.want {