Files
picoclaw/pkg/providers/azure/provider_test.go
T
2026-03-15 12:45:11 +08:00

233 lines
6.9 KiB
Go

package azure
import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"
)
// writeValidResponse writes a minimal valid Azure OpenAI chat completion
// response: one choice whose message content is "ok" and whose
// finish_reason is "stop", served as application/json.
//
// If the payload cannot be encoded, a 500 response is written instead so the
// client under test fails loudly rather than silently parsing an empty body.
func writeValidResponse(w http.ResponseWriter) {
	resp := map[string]any{
		"choices": []map[string]any{
			{
				"message":       map[string]any{"content": "ok"},
				"finish_reason": "stop",
			},
		},
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		// json.Encoder marshals the whole value before writing, so on a
		// marshal error nothing has been sent and the status can still change.
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}
// TestProviderChat_AzureURLConstruction checks that Chat posts to
// /openai/deployments/<deployment>/chat/completions and sends the
// package's azureAPIVersion as the api-version query parameter.
func TestProviderChat_AzureURLConstruction(t *testing.T) {
	var (
		gotPath       string
		gotAPIVersion string
	)
	handler := func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotAPIVersion = r.URL.Query().Get("api-version")
		writeValidResponse(w)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("test-key", srv.URL, "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), msgs, nil, "my-gpt5-deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	const wantPath = "/openai/deployments/my-gpt5-deployment/chat/completions"
	if gotPath != wantPath {
		t.Errorf("URL path = %q, want %q", gotPath, wantPath)
	}
	if gotAPIVersion != azureAPIVersion {
		t.Errorf("api-version = %q, want %q", gotAPIVersion, azureAPIVersion)
	}
}
// TestProviderChat_AzureAuthHeader checks that the key is sent in the
// Api-Key header and that no Authorization header is sent.
func TestProviderChat_AzureAuthHeader(t *testing.T) {
	const apiKey = "test-azure-key"

	var gotAPIKey, gotAuth string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotAPIKey, gotAuth = r.Header.Get("Api-Key"), r.Header.Get("Authorization")
		writeValidResponse(w)
	}))
	defer srv.Close()

	provider := NewProvider(apiKey, srv.URL, "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), msgs, nil, "deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if gotAPIKey != apiKey {
		t.Errorf("api-key header = %q, want %q", gotAPIKey, apiKey)
	}
	if gotAuth != "" {
		t.Errorf("Authorization header should be empty, got %q", gotAuth)
	}
}
// TestProviderChat_AzureOmitsModelFromBody checks that the JSON request body
// sent to Azure OpenAI carries no "model" field.
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
	var requestBody map[string]any
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Errorf (unlike t.Fatalf) may be called from the server goroutine.
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			t.Errorf("decode request body: %v", err)
		}
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// A missing or undecodable body leaves requestBody nil, and a lookup on a
	// nil map reports "absent" — which would make the check below pass trivially.
	if requestBody == nil {
		t.Fatal("server did not receive a JSON request body")
	}
	if _, exists := requestBody["model"]; exists {
		t.Error("request body should not contain 'model' field for Azure OpenAI")
	}
}
// TestProviderChat_AzureUsesMaxCompletionTokens checks that a "max_tokens"
// option is sent as "max_completion_tokens" with the same value, and that
// "max_tokens" itself is not sent.
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
	var requestBody map[string]any
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Errorf (unlike t.Fatalf) may be called from the server goroutine.
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			t.Errorf("decode request body: %v", err)
		}
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	_, err := p.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hi"}},
		nil,
		"deployment",
		map[string]any{"max_tokens": 2048},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// Guard against a nil map, on which every "absent" check passes trivially.
	if requestBody == nil {
		t.Fatal("server did not receive a JSON request body")
	}
	got, exists := requestBody["max_completion_tokens"]
	if !exists {
		t.Error("request body should contain 'max_completion_tokens'")
	} else if got != float64(2048) {
		// encoding/json decodes JSON numbers into float64 when the target is any.
		t.Errorf("max_completion_tokens = %v, want 2048", got)
	}
	if _, exists := requestBody["max_tokens"]; exists {
		t.Error("request body should not contain 'max_tokens'")
	}
}
// TestProviderChat_AzureHTTPError checks that Chat returns an error when the
// server answers with HTTP 401.
func TestProviderChat_AzureHTTPError(t *testing.T) {
	serverHit := false
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		serverHit = true
		http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
	}))
	defer server.Close()

	p := NewProvider("bad-key", server.URL, "")
	_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	// A non-nil error alone could come from building the request; make sure
	// the request reached the server, so the error reflects the 401 response.
	if !serverHit {
		t.Fatalf("request never reached the test server; got unrelated error: %v", err)
	}
}
// TestProviderChat_AzureParseToolCalls checks that a response with one
// function tool call is parsed into exactly one ToolCalls entry with the
// function's name.
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := map[string]any{
			"choices": []map[string]any{
				{
					"message": map[string]any{
						"content": "",
						"tool_calls": []map[string]any{
							{
								"id":   "call_1",
								"type": "function",
								"function": map[string]any{
									"name":      "get_weather",
									"arguments": `{"city":"Seattle"}`,
								},
							},
						},
					},
					"finish_reason": "tool_calls",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		// t.Errorf (unlike t.Fatalf) may be called from the server goroutine.
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if len(out.ToolCalls) != 1 {
		t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
	}
	if out.ToolCalls[0].Name != "get_weather" {
		t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
	}
}
// TestProvider_AzureEmptyAPIBase checks that Chat fails when the provider was
// created with an empty API base.
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
	provider := NewProvider("test-key", "", "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), msgs, nil, "deployment", nil); err == nil {
		t.Fatal("expected error for empty API base")
	}
}
// TestProvider_AzureRequestTimeoutDefault checks that, with no options,
// the HTTP client timeout equals defaultRequestTimeout.
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
	provider := NewProvider("test-key", "https://example.com", "")
	if got := provider.httpClient.Timeout; got != defaultRequestTimeout {
		t.Errorf("timeout = %v, want %v", got, defaultRequestTimeout)
	}
}
// TestProvider_AzureRequestTimeoutOverride checks that WithRequestTimeout
// sets the HTTP client timeout.
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
	const want = 300 * time.Second
	provider := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(want))
	if got := provider.httpClient.Timeout; got != want {
		t.Errorf("timeout = %v, want %v", got, want)
	}
}
// TestProvider_AzureNewProviderWithTimeout checks that passing 180 to
// NewProviderWithTimeout yields a 180-second HTTP client timeout.
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
	const want = 180 * time.Second
	provider := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
	if got := provider.httpClient.Timeout; got != want {
		t.Errorf("timeout = %v, want %v", got, want)
	}
}
// TestProviderChat_AzureDeploymentNameEscaped checks that a deployment name
// containing a space, slashes and ".." segments reaches the server intact, as
// a single percent-encoded path segment. If it did not, the name could
// redirect the request to another endpoint on the host (path injection).
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
	const (
		deployment = "my deploy/../../admin"
		prefix     = "/openai/deployments/"
		suffix     = "/chat/completions"
	)

	var decodedPath, escapedPath string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		decodedPath = r.URL.Path
		// EscapedPath keeps percent-encoding such as %2F, which Path decodes away.
		escapedPath = r.URL.EscapedPath()
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, deployment, nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// Rejecting only the exact unescaped string is not enough: a client that
	// resolved the dot segments (e.g. via path.Join) would request
	// "/openai/admin/chat/completions" and slip through. Require the name to
	// survive verbatim after decoding...
	if wantPath := prefix + deployment + suffix; decodedPath != wantPath {
		t.Fatalf("decoded path = %q, want %q — deployment name was altered or dot segments resolved", decodedPath, wantPath)
	}
	// ...and its slashes to be escaped, so it stays a single path segment.
	segment := strings.TrimSuffix(strings.TrimPrefix(escapedPath, prefix), suffix)
	if strings.Contains(segment, "/") {
		t.Fatalf("deployment segment %q in %q contains an unescaped '/' — path injection possible", segment, escapedPath)
	}
}