mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(provider): add support for azure openai provider (#1422)
* Add support for azure openai provider * Add checks for deployment model name * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Addressing @Copilot suggestion to remove the init() function which seemed redundant * Fix readme * Fix linting checks --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -384,6 +384,15 @@ func DefaultConfig() *Config {
|
||||
APIBase: "http://localhost:8000/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Azure OpenAI - https://portal.azure.com
|
||||
// model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name
|
||||
{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIBase: "https://your-resource.openai.azure.com",
|
||||
APIKey: "",
|
||||
},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "127.0.0.1",
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
)
|
||||
|
||||
const (
|
||||
// azureAPIVersion is the Azure OpenAI API version used for all requests.
|
||||
azureAPIVersion = "2024-10-21"
|
||||
defaultRequestTimeout = common.DefaultRequestTimeout
|
||||
)
|
||||
|
||||
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
|
||||
// It handles Azure-specific authentication (api-key header), URL construction
|
||||
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Option configures the Azure Provider.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithRequestTimeout sets the HTTP request timeout.
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(p *Provider) {
|
||||
if timeout > 0 {
|
||||
p.httpClient.Timeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider creates a new Azure OpenAI provider.
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(p)
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
|
||||
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
|
||||
return NewProvider(
|
||||
apiKey, apiBase, proxy,
|
||||
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
)
|
||||
}
|
||||
|
||||
// Chat sends a chat completion request to the Azure OpenAI endpoint.
|
||||
// The model parameter is used as the Azure deployment name in the URL.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("Azure API base not configured")
|
||||
}
|
||||
|
||||
// model is the deployment name for Azure OpenAI
|
||||
deployment := model
|
||||
|
||||
// Build Azure-specific URL safely using url.JoinPath and query encoding
|
||||
// to prevent path traversal or query injection via deployment names.
|
||||
base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
|
||||
}
|
||||
requestURL := base + "?api-version=" + azureAPIVersion
|
||||
|
||||
// Build request body — no "model" field (Azure infers from deployment URL)
|
||||
requestBody := map[string]any{
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
// Azure OpenAI always uses max_completion_tokens
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Azure uses api-key header instead of Authorization: Bearer
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("api-key", p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string as Azure deployments are user-configured.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
|
||||
func writeValidResponse(w http.ResponseWriter) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureURLConstruction(t *testing.T) {
|
||||
var capturedPath string
|
||||
var capturedAPIVersion string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
capturedAPIVersion = r.URL.Query().Get("api-version")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
|
||||
if capturedPath != wantPath {
|
||||
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
|
||||
}
|
||||
if capturedAPIVersion != azureAPIVersion {
|
||||
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureAuthHeader(t *testing.T) {
|
||||
var capturedAPIKey string
|
||||
var capturedAuth string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAPIKey = r.Header.Get("api-key")
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-azure-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if capturedAPIKey != "test-azure-key" {
|
||||
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
|
||||
}
|
||||
if capturedAuth != "" {
|
||||
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["model"]; exists {
|
||||
t.Error("request body should not contain 'model' field for Azure OpenAI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"deployment",
|
||||
map[string]any{"max_tokens": 2048},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["max_completion_tokens"]; !exists {
|
||||
t.Error("request body should contain 'max_completion_tokens'")
|
||||
}
|
||||
if _, exists := requestBody["max_tokens"]; exists {
|
||||
t.Error("request body should not contain 'max_tokens'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureHTTPError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("bad-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "",
|
||||
"tool_calls": []map[string]any{
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"arguments": `{"city":"Seattle"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
|
||||
p := NewProvider("test-key", "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty API base")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "")
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
|
||||
if p.httpClient.Timeout != 300*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
|
||||
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
|
||||
if p.httpClient.Timeout != 180*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
|
||||
var capturedPath string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
|
||||
if capturedPath == "" {
|
||||
capturedPath = r.URL.Path
|
||||
}
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
|
||||
// Deployment name with characters that could cause path injection
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// The slash and special chars in the deployment name must be escaped, not treated as path separators
|
||||
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
|
||||
t.Fatal("deployment name was interpolated without escaping — path injection possible")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package common provides shared utilities used by multiple LLM provider
|
||||
// implementations (openai_compat, azure, etc.).
|
||||
package common
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// Re-export protocol types used across providers.
|
||||
type (
|
||||
ToolCall = protocoltypes.ToolCall
|
||||
FunctionCall = protocoltypes.FunctionCall
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
UsageInfo = protocoltypes.UsageInfo
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
ExtraContent = protocoltypes.ExtraContent
|
||||
GoogleExtra = protocoltypes.GoogleExtra
|
||||
ReasoningDetail = protocoltypes.ReasoningDetail
|
||||
)
|
||||
|
||||
const DefaultRequestTimeout = 120 * time.Second
|
||||
|
||||
// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout.
|
||||
func NewHTTPClient(proxy string) *http.Client {
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRequestTimeout,
|
||||
}
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
// Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.)
|
||||
if base, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
tr := base.Clone()
|
||||
tr.Proxy = http.ProxyURL(parsed)
|
||||
client.Transport = tr
|
||||
} else {
|
||||
// Fallback: minimal transport if DefaultTransport is not *http.Transport.
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("common: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// --- Message serialization ---
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// SerializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func SerializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
|
||||
func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]any)
|
||||
name := ""
|
||||
|
||||
// Extract thought_signature from Gemini/Google-specific extra content
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
arguments = DecodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
toolCall := ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
|
||||
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// --- HTTP response helpers ---
|
||||
|
||||
// HandleErrorResponse reads a non-200 response body and returns an appropriate error.
|
||||
func HandleErrorResponse(resp *http.Response, apiBase string) error {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if LooksLikeHTML(body, contentType) {
|
||||
return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase)
|
||||
}
|
||||
return fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
ResponsePreview(body, 128),
|
||||
)
|
||||
}
|
||||
|
||||
// ReadAndParseResponse peeks at the response body to detect HTML errors,
|
||||
// then parses the JSON response into an LLMResponse.
|
||||
func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256)
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if LooksLikeHTML(prefix, contentType) {
|
||||
return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase)
|
||||
}
|
||||
out, err := ParseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LooksLikeHTML checks if the response body appears to be HTML.
|
||||
func LooksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
// WrapHTMLResponseError creates a descriptive error for HTML responses.
|
||||
func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := ResponsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
// ResponsePreview returns a truncated preview of response body for error messages.
|
||||
func ResponsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Numeric helpers ---
|
||||
|
||||
// AsInt converts various numeric types to int.
|
||||
func AsInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// AsFloat converts various numeric types to float64.
|
||||
func AsFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,558 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// --- NewHTTPClient tests ---
|
||||
|
||||
func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Timeout != DefaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_WithProxy(t *testing.T) {
|
||||
client := NewHTTPClient("http://127.0.0.1:8080")
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
|
||||
}
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
|
||||
gotProxy, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy function error: %v", err)
|
||||
}
|
||||
if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" {
|
||||
t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_NoProxy(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Transport != nil {
|
||||
t.Errorf("expected nil transport without proxy, got %T", client.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
|
||||
// Should not panic, just log and return client without proxy
|
||||
client := NewHTTPClient("://bad-url")
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil client even with invalid proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SerializeMessages tests ---
|
||||
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["content"] != "hello" {
|
||||
t.Errorf("expected plain string content, got %v", msgs[0]["content"])
|
||||
}
|
||||
if msgs[1]["reasoning_content"] != "thinking..." {
|
||||
t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["tool_call_id"] != "call_1" {
|
||||
t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "you are helpful",
|
||||
SystemParts: []protocoltypes.ContentBlock{
|
||||
{Type: "text", Text: "you are helpful"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
if strings.Contains(string(data), "system_parts") {
|
||||
t.Error("system_parts should not appear in serialized output")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse tests ---
|
||||
|
||||
func TestParseResponse_BasicContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "hello world" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "hello world")
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_EmptyChoices(t *testing.T) {
|
||||
body := `{"choices":[]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", out.Content)
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithToolCalls(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["city"] != "SF" {
|
||||
t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithUsage(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Usage == nil {
|
||||
t.Fatal("Usage is nil")
|
||||
}
|
||||
if out.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithReasoningContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.ReasoningContent != "Let me think... 1+1=2" {
|
||||
t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := ParseResponse(strings.NewReader("not json"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- DecodeToolCallArguments tests ---
|
||||
|
||||
func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "Seattle" {
|
||||
t.Errorf("city = %v, want Seattle", args["city"])
|
||||
}
|
||||
if args["units"] != "metric" {
|
||||
t.Errorf("units = %v, want metric", args["units"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`"{\"city\":\"SF\"}"`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "SF" {
|
||||
t.Errorf("city = %v, want SF", args["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(nil, "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_NullInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`null`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
|
||||
if _, ok := args["raw"]; !ok {
|
||||
t.Error("expected 'raw' fallback key for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`" "`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map for whitespace string, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse tests ---
|
||||
|
||||
func TestHandleErrorResponse_JSONError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"bad request"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "400") {
|
||||
t.Errorf("error should contain status code, got %v", err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "HTML") {
|
||||
t.Errorf("should not mention HTML for JSON error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleErrorResponse_HTMLError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>bad gateway</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error message, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse tests ---
|
||||
|
||||
func TestReadAndParseResponse_ValidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, err := ReadAndParseResponse(resp, server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAndParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "ok" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>login page</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTML response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- LooksLikeHTML tests ---
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "text/html; charset=utf-8") {
|
||||
t.Error("expected true for text/html content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "application/xhtml+xml") {
|
||||
t.Error("expected true for xhtml content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"doctype", "<!DOCTYPE html><html>"},
|
||||
{"html tag", "<html><body>"},
|
||||
{"head tag", "<head><title>"},
|
||||
{"body tag", "<body>content"},
|
||||
{"whitespace before", " \n\t<!DOCTYPE html>"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !LooksLikeHTML([]byte(tt.body), "application/json") {
|
||||
t.Errorf("expected true for body %q", tt.body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_NotHTML(t *testing.T) {
|
||||
if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") {
|
||||
t.Error("expected false for JSON body")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResponsePreview tests ---
|
||||
|
||||
func TestResponsePreview_Short(t *testing.T) {
|
||||
got := ResponsePreview([]byte("hello"), 128)
|
||||
if got != "hello" {
|
||||
t.Errorf("got %q, want %q", got, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Truncated(t *testing.T) {
|
||||
body := strings.Repeat("a", 200)
|
||||
got := ResponsePreview([]byte(body), 128)
|
||||
if len(got) != 131 { // 128 + "..."
|
||||
t.Errorf("len = %d, want 131", len(got))
|
||||
}
|
||||
if !strings.HasSuffix(got, "...") {
|
||||
t.Error("expected ... suffix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Empty(t *testing.T) {
|
||||
got := ResponsePreview([]byte(""), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Whitespace(t *testing.T) {
|
||||
got := ResponsePreview([]byte(" \n\t "), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsInt tests ---
|
||||
|
||||
func TestAsInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want int
|
||||
ok bool
|
||||
}{
|
||||
{"int", 42, 42, true},
|
||||
{"int64", int64(99), 99, true},
|
||||
{"float64", float64(512), 512, true},
|
||||
{"float32", float32(256), 256, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsInt(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsFloat tests ---
|
||||
|
||||
func TestAsFloat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want float64
|
||||
ok bool
|
||||
}{
|
||||
{"float64", float64(0.7), 0.7, true},
|
||||
{"float32", float32(0.5), float64(float32(0.5)), true},
|
||||
{"int", 1, 1.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsFloat(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- WrapHTMLResponseError tests ---
|
||||
|
||||
func TestWrapHTMLResponseError(t *testing.T) {
|
||||
err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", "https://api.example.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "502") {
|
||||
t.Errorf("expected status code in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "https://api.example.com") {
|
||||
t.Errorf("expected api base in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML mention in error, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse with read failure ---
|
||||
|
||||
func TestHandleErrorResponse_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
// empty body
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected status code, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse with invalid JSON ---
|
||||
|
||||
func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("not valid json"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse with thought_signature (Google/Gemini) ---
|
||||
|
||||
func TestParseResponse_WithThoughtSignature(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].ThoughtSignature != "sig123" {
|
||||
t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil {
|
||||
t.Fatal("ExtraContent.Google is nil")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" {
|
||||
t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
|
||||
out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123")
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
||||
@@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "azure", "azure-openai":
|
||||
// Azure OpenAI uses deployment-based URLs, api-key header auth,
|
||||
// and always sends max_completion_tokens.
|
||||
if cfg.APIKey == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for azure protocol")
|
||||
}
|
||||
if cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf(
|
||||
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
|
||||
)
|
||||
}
|
||||
return azure.NewProviderWithTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
|
||||
@@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) {
|
||||
wantProtocol: "nvidia",
|
||||
wantModelID: "meta/llama-3.1-8b",
|
||||
},
|
||||
{
|
||||
name: "azure with prefix",
|
||||
model: "azure/my-gpt5-deployment",
|
||||
wantProtocol: "azure",
|
||||
wantModelID: "my-gpt5-deployment",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
|
||||
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Azure(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-gpt5-deployment" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt4",
|
||||
Model: "azure-openai/my-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-deployment" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-deployment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API base")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -38,7 +36,7 @@ type Provider struct {
|
||||
|
||||
type Option func(*Provider)
|
||||
|
||||
const defaultRequestTimeout = 120 * time.Second
|
||||
const defaultRequestTimeout = common.DefaultRequestTimeout
|
||||
|
||||
func WithMaxTokensField(maxTokensField string) Option {
|
||||
return func(p *Provider) {
|
||||
@@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option {
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: defaultRequestTimeout,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
} else {
|
||||
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -117,7 +100,7 @@ func (p *Provider) Chat(
|
||||
|
||||
requestBody := map[string]any{
|
||||
"model": model,
|
||||
"messages": serializeMessages(messages),
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
@@ -125,7 +108,7 @@ func (p *Provider) Chat(
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := asInt(options["max_tokens"]); ok {
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
|
||||
fieldName := p.maxTokensField
|
||||
if fieldName == "" {
|
||||
@@ -141,7 +124,7 @@ func (p *Provider) Chat(
|
||||
requestBody[fieldName] = maxTokens
|
||||
}
|
||||
|
||||
if temperature, ok := asFloat(options["temperature"]); ok {
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1.
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
@@ -185,275 +168,11 @@ func (p *Provider) Chat(
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if looksLikeHTML(body, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
responsePreview(body, 128),
|
||||
)
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
// Peek without consuming so the full stream reaches the JSON decoder.
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if looksLikeHTML(prefix, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
|
||||
}
|
||||
|
||||
out, err := parseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := responsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
func looksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func parseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]any)
|
||||
name := ""
|
||||
|
||||
// Extract thought_signature from Gemini/Google-specific extra content
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
|
||||
toolCall := ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// serializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func serializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
@@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func asInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
@@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
@@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
@@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
raw := string(data)
|
||||
|
||||
Reference in New Issue
Block a user