mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(provider): add support for azure openai provider (#1422)
* Add support for azure openai provider * Add checks for deployment model name * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Addressing @Copilot suggestion to remove the init() function which seemed redundant * Fix readme * Fix linting checks --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,380 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package common provides shared utilities used by multiple LLM provider
|
||||
// implementations (openai_compat, azure, etc.).
|
||||
package common
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// Re-export protocol types used across providers.
|
||||
type (
|
||||
ToolCall = protocoltypes.ToolCall
|
||||
FunctionCall = protocoltypes.FunctionCall
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
UsageInfo = protocoltypes.UsageInfo
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
ExtraContent = protocoltypes.ExtraContent
|
||||
GoogleExtra = protocoltypes.GoogleExtra
|
||||
ReasoningDetail = protocoltypes.ReasoningDetail
|
||||
)
|
||||
|
||||
const DefaultRequestTimeout = 120 * time.Second
|
||||
|
||||
// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout.
|
||||
func NewHTTPClient(proxy string) *http.Client {
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRequestTimeout,
|
||||
}
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
// Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.)
|
||||
if base, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
tr := base.Clone()
|
||||
tr.Proxy = http.ProxyURL(parsed)
|
||||
client.Transport = tr
|
||||
} else {
|
||||
// Fallback: minimal transport if DefaultTransport is not *http.Transport.
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("common: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// --- Message serialization ---
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// SerializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func SerializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
|
||||
func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]any)
|
||||
name := ""
|
||||
|
||||
// Extract thought_signature from Gemini/Google-specific extra content
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
arguments = DecodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
toolCall := ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
|
||||
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// --- HTTP response helpers ---
|
||||
|
||||
// HandleErrorResponse reads a non-200 response body and returns an appropriate error.
|
||||
func HandleErrorResponse(resp *http.Response, apiBase string) error {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if LooksLikeHTML(body, contentType) {
|
||||
return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase)
|
||||
}
|
||||
return fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
ResponsePreview(body, 128),
|
||||
)
|
||||
}
|
||||
|
||||
// ReadAndParseResponse peeks at the response body to detect HTML errors,
|
||||
// then parses the JSON response into an LLMResponse.
|
||||
func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256)
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if LooksLikeHTML(prefix, contentType) {
|
||||
return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase)
|
||||
}
|
||||
out, err := ParseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LooksLikeHTML checks if the response body appears to be HTML.
|
||||
func LooksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
// WrapHTMLResponseError creates a descriptive error for HTML responses.
|
||||
func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := ResponsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
// ResponsePreview returns a truncated preview of response body for error messages.
|
||||
func ResponsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Numeric helpers ---
|
||||
|
||||
// AsInt converts various numeric types to int.
|
||||
func AsInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// AsFloat converts various numeric types to float64.
|
||||
func AsFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,558 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// --- NewHTTPClient tests ---
|
||||
|
||||
func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Timeout != DefaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_WithProxy(t *testing.T) {
|
||||
client := NewHTTPClient("http://127.0.0.1:8080")
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
|
||||
}
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
|
||||
gotProxy, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy function error: %v", err)
|
||||
}
|
||||
if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" {
|
||||
t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_NoProxy(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Transport != nil {
|
||||
t.Errorf("expected nil transport without proxy, got %T", client.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
|
||||
// Should not panic, just log and return client without proxy
|
||||
client := NewHTTPClient("://bad-url")
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil client even with invalid proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SerializeMessages tests ---
|
||||
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["content"] != "hello" {
|
||||
t.Errorf("expected plain string content, got %v", msgs[0]["content"])
|
||||
}
|
||||
if msgs[1]["reasoning_content"] != "thinking..." {
|
||||
t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["tool_call_id"] != "call_1" {
|
||||
t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "you are helpful",
|
||||
SystemParts: []protocoltypes.ContentBlock{
|
||||
{Type: "text", Text: "you are helpful"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
if strings.Contains(string(data), "system_parts") {
|
||||
t.Error("system_parts should not appear in serialized output")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse tests ---
|
||||
|
||||
func TestParseResponse_BasicContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "hello world" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "hello world")
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_EmptyChoices(t *testing.T) {
|
||||
body := `{"choices":[]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", out.Content)
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithToolCalls(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["city"] != "SF" {
|
||||
t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithUsage(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Usage == nil {
|
||||
t.Fatal("Usage is nil")
|
||||
}
|
||||
if out.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithReasoningContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.ReasoningContent != "Let me think... 1+1=2" {
|
||||
t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := ParseResponse(strings.NewReader("not json"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- DecodeToolCallArguments tests ---
|
||||
|
||||
func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "Seattle" {
|
||||
t.Errorf("city = %v, want Seattle", args["city"])
|
||||
}
|
||||
if args["units"] != "metric" {
|
||||
t.Errorf("units = %v, want metric", args["units"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`"{\"city\":\"SF\"}"`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "SF" {
|
||||
t.Errorf("city = %v, want SF", args["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(nil, "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_NullInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`null`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
|
||||
if _, ok := args["raw"]; !ok {
|
||||
t.Error("expected 'raw' fallback key for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`" "`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map for whitespace string, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse tests ---
|
||||
|
||||
func TestHandleErrorResponse_JSONError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"bad request"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "400") {
|
||||
t.Errorf("error should contain status code, got %v", err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "HTML") {
|
||||
t.Errorf("should not mention HTML for JSON error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleErrorResponse_HTMLError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>bad gateway</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error message, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse tests ---
|
||||
|
||||
func TestReadAndParseResponse_ValidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, err := ReadAndParseResponse(resp, server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAndParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "ok" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>login page</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTML response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- LooksLikeHTML tests ---
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "text/html; charset=utf-8") {
|
||||
t.Error("expected true for text/html content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "application/xhtml+xml") {
|
||||
t.Error("expected true for xhtml content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"doctype", "<!DOCTYPE html><html>"},
|
||||
{"html tag", "<html><body>"},
|
||||
{"head tag", "<head><title>"},
|
||||
{"body tag", "<body>content"},
|
||||
{"whitespace before", " \n\t<!DOCTYPE html>"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !LooksLikeHTML([]byte(tt.body), "application/json") {
|
||||
t.Errorf("expected true for body %q", tt.body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_NotHTML(t *testing.T) {
|
||||
if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") {
|
||||
t.Error("expected false for JSON body")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResponsePreview tests ---
|
||||
|
||||
func TestResponsePreview_Short(t *testing.T) {
|
||||
got := ResponsePreview([]byte("hello"), 128)
|
||||
if got != "hello" {
|
||||
t.Errorf("got %q, want %q", got, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Truncated(t *testing.T) {
|
||||
body := strings.Repeat("a", 200)
|
||||
got := ResponsePreview([]byte(body), 128)
|
||||
if len(got) != 131 { // 128 + "..."
|
||||
t.Errorf("len = %d, want 131", len(got))
|
||||
}
|
||||
if !strings.HasSuffix(got, "...") {
|
||||
t.Error("expected ... suffix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Empty(t *testing.T) {
|
||||
got := ResponsePreview([]byte(""), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Whitespace(t *testing.T) {
|
||||
got := ResponsePreview([]byte(" \n\t "), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsInt tests ---
|
||||
|
||||
func TestAsInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want int
|
||||
ok bool
|
||||
}{
|
||||
{"int", 42, 42, true},
|
||||
{"int64", int64(99), 99, true},
|
||||
{"float64", float64(512), 512, true},
|
||||
{"float32", float32(256), 256, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsInt(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsFloat tests ---
|
||||
|
||||
func TestAsFloat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want float64
|
||||
ok bool
|
||||
}{
|
||||
{"float64", float64(0.7), 0.7, true},
|
||||
{"float32", float32(0.5), float64(float32(0.5)), true},
|
||||
{"int", 1, 1.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsFloat(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- WrapHTMLResponseError tests ---
|
||||
|
||||
func TestWrapHTMLResponseError(t *testing.T) {
|
||||
err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", "https://api.example.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "502") {
|
||||
t.Errorf("expected status code in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "https://api.example.com") {
|
||||
t.Errorf("expected api base in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML mention in error, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse with read failure ---
|
||||
|
||||
func TestHandleErrorResponse_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
// empty body
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected status code, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse with invalid JSON ---
|
||||
|
||||
func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("not valid json"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse with thought_signature (Google/Gemini) ---
|
||||
|
||||
func TestParseResponse_WithThoughtSignature(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].ThoughtSignature != "sig123" {
|
||||
t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil {
|
||||
t.Fatal("ExtraContent.Google is nil")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" {
|
||||
t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
|
||||
out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user