// Mirror of https://github.com/sipeed/picoclaw.git, synced 2026-06-12 18:08:54 +00:00.
// Original file: 233 lines, 6.9 KiB, Go.

package azure
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// writeValidResponse replies with the smallest Azure OpenAI chat completion
// payload the tests need: a single choice whose message content is "ok" and
// whose finish reason is "stop", labelled as JSON.
func writeValidResponse(w http.ResponseWriter) {
	choice := map[string]any{
		"message":       map[string]any{"content": "ok"},
		"finish_reason": "stop",
	}
	payload := map[string]any{"choices": []map[string]any{choice}}

	w.Header().Set("Content-Type", "application/json")
	// Encoding a fixed literal map cannot fail, so the error is discarded.
	_ = json.NewEncoder(w).Encode(payload)
}
|
|
|
|
func TestProviderChat_AzureURLConstruction(t *testing.T) {
|
|
var capturedPath string
|
|
var capturedAPIVersion string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedPath = r.URL.Path
|
|
capturedAPIVersion = r.URL.Query().Get("api-version")
|
|
writeValidResponse(w)
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewProvider("test-key", server.URL, "")
|
|
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
|
|
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
|
|
if capturedPath != wantPath {
|
|
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
|
|
}
|
|
if capturedAPIVersion != azureAPIVersion {
|
|
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
|
|
}
|
|
}
|
|
|
|
func TestProviderChat_AzureAuthHeader(t *testing.T) {
|
|
var capturedAPIKey string
|
|
var capturedAuth string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedAPIKey = r.Header.Get("Api-Key")
|
|
capturedAuth = r.Header.Get("Authorization")
|
|
writeValidResponse(w)
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewProvider("test-azure-key", server.URL, "")
|
|
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
|
|
if capturedAPIKey != "test-azure-key" {
|
|
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
|
|
}
|
|
if capturedAuth != "" {
|
|
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
|
|
}
|
|
}
|
|
|
|
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
|
|
var requestBody map[string]any
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewDecoder(r.Body).Decode(&requestBody)
|
|
writeValidResponse(w)
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewProvider("test-key", server.URL, "")
|
|
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
|
|
if _, exists := requestBody["model"]; exists {
|
|
t.Error("request body should not contain 'model' field for Azure OpenAI")
|
|
}
|
|
}
|
|
|
|
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
|
|
var requestBody map[string]any
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewDecoder(r.Body).Decode(&requestBody)
|
|
writeValidResponse(w)
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewProvider("test-key", server.URL, "")
|
|
_, err := p.Chat(
|
|
t.Context(),
|
|
[]Message{{Role: "user", Content: "hi"}},
|
|
nil,
|
|
"deployment",
|
|
map[string]any{"max_tokens": 2048},
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
|
|
if _, exists := requestBody["max_completion_tokens"]; !exists {
|
|
t.Error("request body should contain 'max_completion_tokens'")
|
|
}
|
|
if _, exists := requestBody["max_tokens"]; exists {
|
|
t.Error("request body should not contain 'max_tokens'")
|
|
}
|
|
}
|
|
|
|
func TestProviderChat_AzureHTTPError(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewProvider("bad-key", server.URL, "")
|
|
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
}
|
|
|
|
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := map[string]any{
|
|
"choices": []map[string]any{
|
|
{
|
|
"message": map[string]any{
|
|
"content": "",
|
|
"tool_calls": []map[string]any{
|
|
{
|
|
"id": "call_1",
|
|
"type": "function",
|
|
"function": map[string]any{
|
|
"name": "get_weather",
|
|
"arguments": `{"city":"Seattle"}`,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
"finish_reason": "tool_calls",
|
|
},
|
|
},
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewProvider("test-key", server.URL, "")
|
|
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
|
|
if len(out.ToolCalls) != 1 {
|
|
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
|
}
|
|
if out.ToolCalls[0].Name != "get_weather" {
|
|
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
|
}
|
|
}
|
|
|
|
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
|
|
p := NewProvider("test-key", "", "")
|
|
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
|
if err == nil {
|
|
t.Fatal("expected error for empty API base")
|
|
}
|
|
}
|
|
|
|
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
|
|
p := NewProvider("test-key", "https://example.com", "")
|
|
if p.httpClient.Timeout != defaultRequestTimeout {
|
|
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
|
}
|
|
}
|
|
|
|
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
|
|
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
|
|
if p.httpClient.Timeout != 300*time.Second {
|
|
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
|
|
}
|
|
}
|
|
|
|
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
|
|
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
|
|
if p.httpClient.Timeout != 180*time.Second {
|
|
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
|
|
}
|
|
}
|
|
|
|
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
|
|
var capturedPath string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
|
|
if capturedPath == "" {
|
|
capturedPath = r.URL.Path
|
|
}
|
|
writeValidResponse(w)
|
|
}))
|
|
defer server.Close()
|
|
|
|
p := NewProvider("test-key", server.URL, "")
|
|
|
|
// Deployment name with characters that could cause path injection
|
|
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
|
|
if err != nil {
|
|
t.Fatalf("Chat() error = %v", err)
|
|
}
|
|
|
|
// The slash and special chars in the deployment name must be escaped, not treated as path separators
|
|
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
|
|
t.Fatal("deployment name was interpolated without escaping — path injection possible")
|
|
}
|
|
}
|