mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(openai_compat): improve prompt_cache_key host matching (#1387)
LGTM! The changes improve the robustness of prompt_cache_key host matching and add Azure OpenAI support. Thanks for the contribution!
This commit is contained in:
@@ -508,8 +508,13 @@ func asFloat(v any) (float64, bool) {
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API supports this. All other OpenAI-compatible providers (Mistral,
|
||||
// Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
|
||||
func supportsPromptCacheKey(apiBase string) bool {
|
||||
return strings.Contains(apiBase, "api.openai.com")
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := u.Hostname()
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
}
|
||||
|
||||
@@ -718,7 +718,10 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
|
||||
// chatWithCacheKey sets up a test server, sends a Chat request with prompt_cache_key,
|
||||
// and returns the decoded request body for assertion.
|
||||
func chatWithCacheKey(t *testing.T, apiBase string) map[string]any {
|
||||
t.Helper()
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -739,13 +742,10 @@ func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Simulate an OpenAI endpoint by overriding the apiBase after creation.
|
||||
p := NewProvider("key", server.URL, "")
|
||||
p.apiBase = "https://api.openai.com/v1"
|
||||
// Point the HTTP client at our test server instead.
|
||||
p.apiBase = apiBase
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
// Redirect all requests to the test server.
|
||||
r.URL, _ = url.Parse(server.URL + r.URL.Path)
|
||||
return http.DefaultTransport.RoundTrip(r)
|
||||
}),
|
||||
@@ -755,14 +755,19 @@ func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
"test-model",
|
||||
map[string]any{"prompt_cache_key": "agent-main"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if requestBody["prompt_cache_key"] != "agent-main" {
|
||||
t.Fatalf("prompt_cache_key = %v, want %q", requestBody["prompt_cache_key"], "agent-main")
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
|
||||
body := chatWithCacheKey(t, "https://api.openai.com/v1")
|
||||
if body["prompt_cache_key"] != "agent-main" {
|
||||
t.Fatalf("prompt_cache_key = %v, want %q", body["prompt_cache_key"], "agent-main")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -781,46 +786,8 @@ func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
p.apiBase = tt.apiBase
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
r.URL, _ = url.Parse(server.URL + r.URL.Path)
|
||||
return http.DefaultTransport.RoundTrip(r)
|
||||
}),
|
||||
}
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"test-model",
|
||||
map[string]any{"prompt_cache_key": "agent-main"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if _, exists := requestBody["prompt_cache_key"]; exists {
|
||||
body := chatWithCacheKey(t, tt.apiBase)
|
||||
if _, exists := body["prompt_cache_key"]; exists {
|
||||
t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name)
|
||||
}
|
||||
})
|
||||
@@ -834,12 +801,20 @@ func TestSupportsPromptCacheKey(t *testing.T) {
|
||||
}{
|
||||
{"https://api.openai.com/v1", true},
|
||||
{"https://api.openai.com/v1/", true},
|
||||
{"https://myresource.openai.azure.com/openai/deployments/gpt-4", true},
|
||||
{"https://eastus.openai.azure.com/v1", true},
|
||||
{"https://api.mistral.ai/v1", false},
|
||||
{"https://generativelanguage.googleapis.com/v1beta", false},
|
||||
{"https://api.deepseek.com/v1", false},
|
||||
{"https://api.groq.com/openai/v1", false},
|
||||
{"http://localhost:11434/v1", false},
|
||||
{"https://openrouter.ai/api/v1", false},
|
||||
// Edge cases: proxy URLs with openai.com in path should NOT match
|
||||
{"https://my-proxy.com/api.openai.com/v1", false},
|
||||
{"https://proxy.example.com/openai.azure.com/v1", false},
|
||||
// Malformed or empty
|
||||
{"", false},
|
||||
{"not-a-url", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := supportsPromptCacheKey(tt.apiBase); got != tt.want {
|
||||
|
||||
Reference in New Issue
Block a user