Files
picoclaw/pkg/providers/azure/provider_test.go
T
Cytown 2c446e1e07 feat: add userAgent config for ModelConfig (#2242)
* feat: add userAgent config for ModelConfig

* update docs for ModelConfig.userAgent

* make defaut userAgent to PicoClaw and add test case
2026-04-02 11:44:13 +08:00

418 lines
13 KiB
Go

package azure
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// writeValidResponse writes a minimal valid Responses API response.
func writeValidResponse(w http.ResponseWriter) {
resp := map[string]any{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]any{
{
"type": "message",
"content": []map[string]any{
{"type": "output_text", "text": "ok"},
},
},
},
"usage": map[string]any{
"input_tokens": 5,
"output_tokens": 2,
"total_tokens": 7,
"input_tokens_details": map[string]any{"cached_tokens": 0},
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func TestProviderChat_AzureURLConstruction(t *testing.T) {
var capturedPath string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
wantPath := "/openai/v1/responses"
if capturedPath != wantPath {
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
}
}
func TestProviderChat_AzureAuthHeader(t *testing.T) {
var capturedAuth string
var capturedAPIKey string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
capturedAPIKey = r.Header.Get("Api-Key")
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-azure-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if capturedAuth != "Bearer test-azure-key" {
t.Errorf("Authorization header = %q, want %q", capturedAuth, "Bearer test-azure-key")
}
if capturedAPIKey != "" {
t.Errorf("Api-Key header should be empty, got %q", capturedAPIKey)
}
}
func TestProviderChat_AzureRequestBodyContainsModel(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["model"] != "my-deployment" {
t.Errorf("model = %v, want %q", requestBody["model"], "my-deployment")
}
}
func TestProviderChat_AzureUsesMaxOutputTokens(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"deployment",
map[string]any{"max_tokens": 2048},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["max_output_tokens"] == nil {
t.Error("request body should contain 'max_output_tokens'")
}
if _, exists := requestBody["max_tokens"]; exists {
t.Error("request body should not contain 'max_tokens'")
}
if _, exists := requestBody["max_completion_tokens"]; exists {
t.Error("request body should not contain 'max_completion_tokens'")
}
}
func TestProviderChat_AzureStoreIsFalse(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["store"] != false {
t.Errorf("store = %v, want false", requestBody["store"])
}
}
func TestProviderChat_AzureHTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
}))
defer server.Close()
p := NewProvider("bad-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProviderChat_AzureRateLimitError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error":{"message":"Rate limit exceeded","type":"rate_limit_error"}}`))
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for 429, got nil")
}
if !strings.Contains(err.Error(), "429") {
t.Errorf("error should contain status code 429, got: %v", err)
}
}
func TestProviderChat_AzureServerError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":{"message":"Internal server error","type":"server_error"}}`))
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for 500, got nil")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("error should contain status code 500, got: %v", err)
}
}
func TestProviderChat_AzureParseTextOutput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
"id": "resp_1",
"object": "response",
"status": "completed",
"output": []map[string]any{
{
"type": "message",
"content": []map[string]any{
{"type": "output_text", "text": "Hello there!"},
},
},
},
"usage": map[string]any{
"input_tokens": 10, "output_tokens": 5, "total_tokens": 15,
"input_tokens_details": map[string]any{"cached_tokens": 0},
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if out.Content != "Hello there!" {
t.Errorf("Content = %q, want %q", out.Content, "Hello there!")
}
if out.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
}
if out.Usage.TotalTokens != 15 {
t.Errorf("TotalTokens = %d, want 15", out.Usage.TotalTokens)
}
}
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
"id": "resp_2",
"object": "response",
"status": "completed",
"output": []map[string]any{
{
"type": "function_call",
"call_id": "call_1",
"name": "get_weather",
"arguments": `{"city":"Seattle"}`,
},
},
"usage": map[string]any{
"input_tokens": 10, "output_tokens": 8, "total_tokens": 18,
"input_tokens_details": map[string]any{"cached_tokens": 0},
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "", "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "tool_calls")
}
}
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
p := NewProvider("test-key", "", "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for empty API base")
}
}
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "", "")
if p.httpClient.Timeout != defaultRequestTimeout {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "", "", WithRequestTimeout(300*time.Second))
if p.httpClient.Timeout != 300*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
}
}
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
p := NewProviderWithTimeout("test-key", "https://example.com", "", "", 180)
if p.httpClient.Timeout != 180*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
}
}
func TestProviderChat_AzureNativeWebSearchInjection(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
tools := []ToolDefinition{
{
Type: "function",
Function: protocoltypes.ToolFunctionDefinition{
Name: "web_search",
Description: "local web search",
Parameters: map[string]any{"type": "object"},
},
},
{
Type: "function",
Function: protocoltypes.ToolFunctionDefinition{
Name: "read_file",
Description: "read a file",
Parameters: map[string]any{"type": "object"},
},
},
}
p := NewProvider("test-key", server.URL, "", "")
// With native_search=true: user-defined web_search should be replaced by built-in
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment",
map[string]any{"native_search": true})
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
toolsAny, ok := requestBody["tools"].([]any)
if !ok {
t.Fatal("request body should contain 'tools' array")
}
if len(toolsAny) != 2 {
t.Fatalf("len(tools) = %d, want 2 (read_file + web_search builtin)", len(toolsAny))
}
// First tool should be read_file (user-defined web_search was skipped)
firstTool, _ := toolsAny[0].(map[string]any)
if firstTool["name"] != "read_file" {
t.Errorf("first tool name = %v, want %q", firstTool["name"], "read_file")
}
// Second tool should be built-in web_search
secondTool, _ := toolsAny[1].(map[string]any)
if secondTool["type"] != "web_search" {
t.Errorf("second tool type = %v, want %q", secondTool["type"], "web_search")
}
}
func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
tools := []ToolDefinition{
{
Type: "function",
Function: protocoltypes.ToolFunctionDefinition{
Name: "web_search",
Description: "local web search",
Parameters: map[string]any{"type": "object"},
},
},
}
p := NewProvider("test-key", server.URL, "", "")
// Without native_search: user-defined web_search should be kept as-is
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
toolsAny, ok := requestBody["tools"].([]any)
if !ok {
t.Fatal("request body should contain 'tools' array")
}
if len(toolsAny) != 1 {
t.Fatalf("len(tools) = %d, want 1", len(toolsAny))
}
// Should be the user-defined function tool, not built-in
tool, _ := toolsAny[0].(map[string]any)
if tool["type"] != "function" {
t.Errorf("tool type = %v, want %q", tool["type"], "function")
}
}