feat(provider): add support for azure openai provider (#1422)

* Add support for azure openai provider

* Add checks for deployment model name

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Address @Copilot suggestion to remove the redundant init() function

* Fix readme

* Fix linting checks

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Kunal Karmakar
2026-03-14 20:22:34 +05:30
committed by GitHub
parent 0f700a6bf0
commit 5fb4b3bedf
16 changed files with 1446 additions and 323 deletions
+1
@@ -991,6 +991,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obtenir Clé](https://www.byteplus.com/) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obtenir un Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obtenir Clé](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+1
@@ -935,6 +935,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [キーを取得](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [トークンを取得](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [キーを取得](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+2
@@ -1006,6 +1006,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
| `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) |
| `vivgrid` | LLM (Vivgrid direct) | [vivgrid.com](https://vivgrid.com) |
| `azure` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
### Model Configuration (model_list)
@@ -1042,6 +1043,7 @@ This design also enables **multi-agent support** with flexible provider selectio
| **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Get Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Get Key](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+1
@@ -987,6 +987,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obter Chave](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obter Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obter Chave](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+1
@@ -956,6 +956,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Lấy Khóa](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Lấy Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Lấy Khóa](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+1
@@ -528,6 +528,7 @@ Agent 读取 HEARTBEAT.md
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [获取密钥](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [获取 Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [获取密钥](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+6
@@ -53,6 +53,12 @@
"api_key": "your-modelscope-access-token",
"api_base": "https://api-inference.modelscope.cn/v1"
},
{
"model_name": "azure-gpt5",
"model": "azure/my-gpt5-deployment",
"api_key": "your-azure-api-key",
"api_base": "https://your-resource.openai.azure.com"
},
{
"model_name": "loadbalanced-gpt-5.4",
"model": "openai/gpt-5.4",
+9
@@ -384,6 +384,15 @@ func DefaultConfig() *Config {
APIBase: "http://localhost:8000/v1",
APIKey: "",
},
// Azure OpenAI - https://portal.azure.com
// model_name is a user-friendly alias; the part of the model field after "azure/" is your Azure deployment name
{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIBase: "https://your-resource.openai.azure.com",
APIKey: "",
},
},
Gateway: GatewayConfig{
Host: "127.0.0.1",
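The comment on the Azure default entry says that `model_name` is only an alias and that the deployment name is whatever follows `azure/` in `model`. A minimal sketch of that split, assuming a plain prefix cut on the first `/`; `splitModel` is a hypothetical helper for illustration and is not a function from this commit:

```go
package main

import (
	"fmt"
	"strings"
)

// splitModel separates "provider/name" into its two parts.
// Hypothetical helper: PicoClaw's real resolution logic is not shown in this diff.
func splitModel(model string) (provider, name string, ok bool) {
	provider, name, ok = strings.Cut(model, "/")
	if !ok || provider == "" || name == "" {
		// Reject values with no prefix or with an empty deployment name.
		return "", "", false
	}
	return provider, name, true
}

func main() {
	provider, deployment, ok := splitModel("azure/my-gpt5-deployment")
	fmt.Println(provider, deployment, ok)
}
```

Rejecting an empty name up front keeps a misconfigured entry such as `azure/` from reaching the URL builder.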
+150
@@ -0,0 +1,150 @@
package azure
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type (
LLMResponse = protocoltypes.LLMResponse
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
)
const (
// azureAPIVersion is the Azure OpenAI API version used for all requests.
azureAPIVersion = "2024-10-21"
defaultRequestTimeout = common.DefaultRequestTimeout
)
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
// It handles Azure-specific authentication (api-key header), URL construction
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
}
// Option configures the Azure Provider.
type Option func(*Provider)
// WithRequestTimeout sets the HTTP request timeout.
func WithRequestTimeout(timeout time.Duration) Option {
return func(p *Provider) {
if timeout > 0 {
p.httpClient.Timeout = timeout
}
}
}
// NewProvider creates a new Azure OpenAI provider.
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: common.NewHTTPClient(proxy),
}
for _, opt := range opts {
if opt != nil {
opt(p)
}
}
return p
}
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
return NewProvider(
apiKey, apiBase, proxy,
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
)
}
// Chat sends a chat completion request to the Azure OpenAI endpoint.
// The model parameter is used as the Azure deployment name in the URL.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("Azure API base not configured")
}
// model is the deployment name for Azure OpenAI
deployment := model
// Build the Azure-specific URL. url.JoinPath cleans "../" segments instead of
// escaping them, so percent-encode the deployment name first to keep it a
// single path segment, and query-escape the api-version value.
base, err := url.JoinPath(p.apiBase, "openai/deployments", url.PathEscape(deployment), "chat/completions")
if err != nil {
return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
}
requestURL := base + "?api-version=" + url.QueryEscape(azureAPIVersion)
// Build request body — no "model" field (Azure infers from deployment URL)
requestBody := map[string]any{
"messages": common.SerializeMessages(messages),
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
// Send the caller's max_tokens option as max_completion_tokens, the only token-limit field this provider emits
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
requestBody["max_completion_tokens"] = maxTokens
}
if temperature, ok := common.AsFloat(options["temperature"]); ok {
requestBody["temperature"] = temperature
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Azure uses api-key header instead of Authorization: Bearer
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("api-key", p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
return common.ReadAndParseResponse(resp, p.apiBase)
}
// GetDefaultModel returns an empty string as Azure deployments are user-configured.
func (p *Provider) GetDefaultModel() string {
return ""
}
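Because `Chat` places the deployment name directly into the request path, it is worth seeing how `url.JoinPath` treats a name containing `../`. The sketch below compares the joined URL with and without `url.PathEscape`; `chatURL` and the `example.openai.azure.com` host are assumptions made for illustration, not part of this commit:

```go
package main

import (
	"fmt"
	"net/url"
)

// chatURL is a hypothetical helper that mirrors the path layout used by Chat.
// When escape is true the deployment name is percent-encoded before joining.
func chatURL(apiBase, deployment string, escape bool) (string, error) {
	if escape {
		deployment = url.PathEscape(deployment)
	}
	return url.JoinPath(apiBase, "openai/deployments", deployment, "chat/completions")
}

func main() {
	for _, escape := range []bool{false, true} {
		u, err := chatURL("https://example.openai.azure.com", "a/../../admin", escape)
		fmt.Printf("escape=%v url=%s err=%v\n", escape, u, err)
	}
}
```

Without escaping, the `../` elements are resolved during the join and the resulting path leaves the `deployments` prefix; with escaping, the whole name stays a single segment.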
+232
@@ -0,0 +1,232 @@
package azure
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
func writeValidResponse(w http.ResponseWriter) {
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func TestProviderChat_AzureURLConstruction(t *testing.T) {
var capturedPath string
var capturedAPIVersion string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
capturedAPIVersion = r.URL.Query().Get("api-version")
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
if capturedPath != wantPath {
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
}
if capturedAPIVersion != azureAPIVersion {
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
}
}
func TestProviderChat_AzureAuthHeader(t *testing.T) {
var capturedAPIKey string
var capturedAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAPIKey = r.Header.Get("api-key")
capturedAuth = r.Header.Get("Authorization")
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-azure-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if capturedAPIKey != "test-azure-key" {
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
}
if capturedAuth != "" {
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
}
}
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, exists := requestBody["model"]; exists {
t.Error("request body should not contain 'model' field for Azure OpenAI")
}
}
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"deployment",
map[string]any{"max_tokens": 2048},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, exists := requestBody["max_completion_tokens"]; !exists {
t.Error("request body should contain 'max_completion_tokens'")
}
if _, exists := requestBody["max_tokens"]; exists {
t.Error("request body should not contain 'max_tokens'")
}
}
func TestProviderChat_AzureHTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
}))
defer server.Close()
p := NewProvider("bad-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{
"content": "",
"tool_calls": []map[string]any{
{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "get_weather",
"arguments": `{"city":"Seattle"}`,
},
},
},
},
"finish_reason": "tool_calls",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
}
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
p := NewProvider("test-key", "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for empty API base")
}
}
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "")
if p.httpClient.Timeout != defaultRequestTimeout {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
if p.httpClient.Timeout != 300*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
}
}
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
if p.httpClient.Timeout != 180*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
}
}
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
var capturedPath string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
if capturedPath == "" {
capturedPath = r.URL.Path
}
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
// Deployment name with characters that could cause path injection
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
// The slash and special chars in the deployment name must be escaped, not treated as path separators
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
t.Fatal("deployment name was interpolated without escaping — path injection possible")
}
}
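The escaping test above only rules out one literal, unescaped path. A stricter check would compare the exact percent-encoded path the handler observes. The sketch below shows that comparison value without opening a socket, under the assumption that the deployment name is passed through `url.PathEscape`; `escapedPathSeen` is a hypothetical helper, not code from this commit:

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
)

// escapedPathSeen reports the percent-encoded path a handler observes when the
// deployment name is escaped with url.PathEscape before being placed in the URL.
// Hypothetical helper for illustration only.
func escapedPathSeen(deployment string) string {
	var seen string
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// EscapedPath keeps an encoded %2F distinct from a real "/" separator.
		seen = r.URL.EscapedPath()
		w.WriteHeader(http.StatusOK)
	})

	target := "/openai/deployments/" + url.PathEscape(deployment) + "/chat/completions"
	req := httptest.NewRequest(http.MethodPost, target, nil)
	handler.ServeHTTP(httptest.NewRecorder(), req)
	return seen
}

func main() {
	fmt.Println(escapedPathSeen("my deploy/../../admin"))
}
```

Asserting on the full expected path, rather than on inequality with one string, would also catch the case where `../` segments are silently resolved.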
+380
@@ -0,0 +1,380 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package common provides shared utilities used by multiple LLM provider
// implementations (openai_compat, azure, etc.).
package common
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Re-export protocol types used across providers.
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
ExtraContent = protocoltypes.ExtraContent
GoogleExtra = protocoltypes.GoogleExtra
ReasoningDetail = protocoltypes.ReasoningDetail
)
const DefaultRequestTimeout = 120 * time.Second
// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout.
func NewHTTPClient(proxy string) *http.Client {
client := &http.Client{
Timeout: DefaultRequestTimeout,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
// Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.)
if base, ok := http.DefaultTransport.(*http.Transport); ok {
tr := base.Clone()
tr.Proxy = http.ProxyURL(parsed)
client.Transport = tr
} else {
// Fallback: minimal transport if DefaultTransport is not *http.Transport.
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
}
} else {
log.Printf("common: invalid proxy URL %q: %v", proxy, err)
}
}
return client
}
// --- Message serialization ---
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
type openaiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// SerializeMessages converts internal Message structs to the OpenAI wire format.
// - Strips SystemParts (unknown to third-party endpoints)
// - Converts messages with Media to multipart content format (text + image_url parts)
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
func SerializeMessages(messages []Message) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
if len(m.Media) == 0 {
out = append(out, openaiMessage{
Role: m.Role,
Content: m.Content,
ReasoningContent: m.ReasoningContent,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
})
continue
}
// Multipart content format for messages with media
parts := make([]map[string]any, 0, 1+len(m.Media))
if m.Content != "" {
parts = append(parts, map[string]any{
"type": "text",
"text": m.Content,
})
}
for _, mediaURL := range m.Media {
if strings.HasPrefix(mediaURL, "data:image/") {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
}
msg := map[string]any{
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
out = append(out, msg)
}
return out
}
// --- Response parsing ---
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
func ParseResponse(body io.Reader) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
Reasoning string `json:"reasoning"`
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
} `json:"function"`
ExtraContent *struct {
Google *struct {
ThoughtSignature string `json:"thought_signature"`
} `json:"google"`
} `json:"extra_content"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]any)
name := ""
// Extract thought_signature from Gemini/Google-specific extra content
thoughtSignature := ""
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
}
if tc.Function != nil {
name = tc.Function.Name
arguments = DecodeToolCallArguments(tc.Function.Arguments, name)
}
toolCall := ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
ThoughtSignature: thoughtSignature,
}
if thoughtSignature != "" {
toolCall.ExtraContent = &ExtraContent{
Google: &GoogleExtra{
ThoughtSignature: thoughtSignature,
},
}
}
toolCalls = append(toolCalls, toolCall)
}
return &LLMResponse{
Content: choice.Message.Content,
ReasoningContent: choice.Message.ReasoningContent,
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
raw = bytes.TrimSpace(raw)
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
return arguments
}
var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
arguments["raw"] = string(raw)
return arguments
}
switch v := decoded.(type) {
case string:
if strings.TrimSpace(v) == "" {
return arguments
}
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = v
}
return arguments
case map[string]any:
return v
default:
log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
arguments["raw"] = string(raw)
return arguments
}
}
// --- HTTP response helpers ---
// HandleErrorResponse reads a non-200 response body and returns an appropriate error.
func HandleErrorResponse(resp *http.Response, apiBase string) error {
contentType := resp.Header.Get("Content-Type")
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
if readErr != nil {
return fmt.Errorf("failed to read response: %w", readErr)
}
if LooksLikeHTML(body, contentType) {
return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase)
}
return fmt.Errorf(
"API request failed:\n Status: %d\n Body: %s",
resp.StatusCode,
ResponsePreview(body, 128),
)
}
// ReadAndParseResponse peeks at the response body to detect HTML errors,
// then parses the JSON response into an LLMResponse.
func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
contentType := resp.Header.Get("Content-Type")
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256)
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if LooksLikeHTML(prefix, contentType) {
return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase)
}
out, err := ParseResponse(reader)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return out, nil
}
// LooksLikeHTML checks if the response body appears to be HTML.
func LooksLikeHTML(body []byte, contentType string) bool {
contentType = strings.ToLower(strings.TrimSpace(contentType))
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
return true
}
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
bytes.HasPrefix(prefix, []byte("<html")) ||
bytes.HasPrefix(prefix, []byte("<head")) ||
bytes.HasPrefix(prefix, []byte("<body"))
}
// WrapHTMLResponseError creates a descriptive error for HTML responses.
func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
respPreview := ResponsePreview(body, 128)
return fmt.Errorf(
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
apiBase,
contentType,
statusCode,
respPreview,
)
}
// ResponsePreview returns a truncated preview of response body for error messages.
func ResponsePreview(body []byte, maxLen int) string {
trimmed := bytes.TrimSpace(body)
if len(trimmed) == 0 {
return "<empty>"
}
if len(trimmed) <= maxLen {
return string(trimmed)
}
return string(trimmed[:maxLen]) + "..."
}
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
i := 0
for i < len(body) {
switch body[i] {
case ' ', '\t', '\n', '\r', '\f', '\v':
i++
default:
end := i + maxLen
if end > len(body) {
end = len(body)
}
return body[i:end]
}
}
return nil
}
// --- Numeric helpers ---
// AsInt converts various numeric types to int.
func AsInt(v any) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
// AsFloat converts various numeric types to float64.
func AsFloat(v any) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
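One reason `AsInt` needs a `float64` case: `encoding/json` decodes every JSON number into `float64` when the target type is `any`. The sketch below assumes the options map was decoded from JSON, which this diff does not show; `asInt` is a local copy of the switch above and `maxTokensFromJSON` is a hypothetical helper:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// asInt mirrors the AsInt switch above so that this sketch is self-contained.
func asInt(v any) (int, bool) {
	switch val := v.(type) {
	case int:
		return val, true
	case int64:
		return int(val), true
	case float64:
		return int(val), true
	case float32:
		return int(val), true
	default:
		return 0, false
	}
}

// maxTokensFromJSON decodes an options object and extracts max_tokens.
// Hypothetical helper; only the option name comes from the Azure provider code.
func maxTokensFromJSON(raw string) (int, bool) {
	var options map[string]any
	if err := json.Unmarshal([]byte(raw), &options); err != nil {
		return 0, false
	}
	// A missing key yields nil, which falls through to the default case.
	return asInt(options["max_tokens"])
}

func main() {
	n, ok := maxTokensFromJSON(`{"max_tokens": 2048}`)
	fmt.Println(n, ok)
}
```

Note that the `float64` branch truncates toward zero, so a fractional value such as `2048.9` would be sent as `2048`.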
+558
@@ -0,0 +1,558 @@
package common
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// --- NewHTTPClient tests ---
func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
client := NewHTTPClient("")
if client.Timeout != DefaultRequestTimeout {
t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
}
}
func TestNewHTTPClient_WithProxy(t *testing.T) {
client := NewHTTPClient("http://127.0.0.1:8080")
transport, ok := client.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
}
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
gotProxy, err := transport.Proxy(req)
if err != nil {
t.Fatalf("proxy function error: %v", err)
}
if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" {
t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy)
}
}
func TestNewHTTPClient_NoProxy(t *testing.T) {
client := NewHTTPClient("")
if client.Transport != nil {
t.Errorf("expected nil transport without proxy, got %T", client.Transport)
}
}
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
// Should not panic, just log and return client without proxy
client := NewHTTPClient("://bad-url")
if client == nil {
t.Fatal("expected non-nil client even with invalid proxy")
}
}
// --- SerializeMessages tests ---
func TestSerializeMessages_PlainText(t *testing.T) {
messages := []Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
if msgs[0]["content"] != "hello" {
t.Errorf("expected plain string content, got %v", msgs[0]["content"])
}
if msgs[1]["reasoning_content"] != "thinking..." {
t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
}
}
func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
content, ok := msgs[0]["content"].([]any)
if !ok {
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
}
if len(content) != 2 {
t.Fatalf("expected 2 content parts, got %d", len(content))
}
}
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []Message{
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
if msgs[0]["tool_call_id"] != "call_1" {
t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
}
}
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
messages := []Message{
{
Role: "system",
Content: "you are helpful",
SystemParts: []protocoltypes.ContentBlock{
{Type: "text", Text: "you are helpful"},
},
},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
if strings.Contains(string(data), "system_parts") {
t.Error("system_parts should not appear in serialized output")
}
}
// --- ParseResponse tests ---
func TestParseResponse_BasicContent(t *testing.T) {
body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.Content != "hello world" {
t.Errorf("Content = %q, want %q", out.Content, "hello world")
}
if out.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
}
}
func TestParseResponse_EmptyChoices(t *testing.T) {
body := `{"choices":[]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.Content != "" {
t.Errorf("Content = %q, want empty", out.Content)
}
if out.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
}
}
func TestParseResponse_WithToolCalls(t *testing.T) {
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
}
func TestParseResponse_WithUsage(t *testing.T) {
body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.Usage == nil {
t.Fatal("Usage is nil")
}
if out.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens)
}
}
func TestParseResponse_WithReasoningContent(t *testing.T) {
body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.ReasoningContent != "Let me think... 1+1=2" {
t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2")
}
}
func TestParseResponse_InvalidJSON(t *testing.T) {
_, err := ParseResponse(strings.NewReader("not json"))
if err == nil {
t.Fatal("expected error for invalid JSON")
}
}
// --- DecodeToolCallArguments tests ---
func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`)
args := DecodeToolCallArguments(raw, "test")
if args["city"] != "Seattle" {
t.Errorf("city = %v, want Seattle", args["city"])
}
if args["units"] != "metric" {
t.Errorf("units = %v, want metric", args["units"])
}
}
func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
raw := json.RawMessage(`"{\"city\":\"SF\"}"`)
args := DecodeToolCallArguments(raw, "test")
if args["city"] != "SF" {
t.Errorf("city = %v, want SF", args["city"])
}
}
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
args := DecodeToolCallArguments(nil, "test")
if len(args) != 0 {
t.Errorf("expected empty map, got %v", args)
}
}
func TestDecodeToolCallArguments_NullInput(t *testing.T) {
args := DecodeToolCallArguments(json.RawMessage(`null`), "test")
if len(args) != 0 {
t.Errorf("expected empty map, got %v", args)
}
}
func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
if _, ok := args["raw"]; !ok {
t.Error("expected 'raw' fallback key for invalid JSON")
}
}
func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
args := DecodeToolCallArguments(json.RawMessage(`" "`), "test")
if len(args) != 0 {
t.Errorf("expected empty map for whitespace string, got %v", args)
}
}
// --- HandleErrorResponse tests ---
func TestHandleErrorResponse_JSONError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"bad request"}`))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
err = HandleErrorResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "400") {
t.Errorf("error should contain status code, got %v", err)
}
if strings.Contains(err.Error(), "HTML") {
t.Errorf("should not mention HTML for JSON error, got %v", err)
}
}
func TestHandleErrorResponse_HTMLError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusBadGateway)
w.Write([]byte("<!DOCTYPE html><html><body>bad gateway</body></html>"))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
err = HandleErrorResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "HTML instead of JSON") {
t.Errorf("expected HTML error message, got %v", err)
}
}
// --- ReadAndParseResponse tests ---
func TestReadAndParseResponse_ValidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
out, err := ReadAndParseResponse(resp, server.URL)
if err != nil {
t.Fatalf("ReadAndParseResponse() error = %v", err)
}
if out.Content != "ok" {
t.Errorf("Content = %q, want %q", out.Content, "ok")
}
}
func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Write([]byte("<!DOCTYPE html><html><body>login page</body></html>"))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
_, err = ReadAndParseResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error for HTML response")
}
if !strings.Contains(err.Error(), "HTML instead of JSON") {
t.Errorf("expected HTML error, got %v", err)
}
}
// --- LooksLikeHTML tests ---
func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
if !LooksLikeHTML(nil, "text/html; charset=utf-8") {
t.Error("expected true for text/html content type")
}
}
func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
if !LooksLikeHTML(nil, "application/xhtml+xml") {
t.Error("expected true for xhtml content type")
}
}
func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
tests := []struct {
name string
body string
}{
{"doctype", "<!DOCTYPE html><html>"},
{"html tag", "<html><body>"},
{"head tag", "<head><title>"},
{"body tag", "<body>content"},
{"whitespace before", " \n\t<!DOCTYPE html>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !LooksLikeHTML([]byte(tt.body), "application/json") {
t.Errorf("expected true for body %q", tt.body)
}
})
}
}
func TestLooksLikeHTML_NotHTML(t *testing.T) {
if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") {
t.Error("expected false for JSON body")
}
}
// --- ResponsePreview tests ---
func TestResponsePreview_Short(t *testing.T) {
got := ResponsePreview([]byte("hello"), 128)
if got != "hello" {
t.Errorf("got %q, want %q", got, "hello")
}
}
func TestResponsePreview_Truncated(t *testing.T) {
body := strings.Repeat("a", 200)
got := ResponsePreview([]byte(body), 128)
if len(got) != 131 { // 128 + "..."
t.Errorf("len = %d, want 131", len(got))
}
if !strings.HasSuffix(got, "...") {
t.Error("expected ... suffix")
}
}
func TestResponsePreview_Empty(t *testing.T) {
got := ResponsePreview([]byte(""), 128)
if got != "<empty>" {
t.Errorf("got %q, want %q", got, "<empty>")
}
}
func TestResponsePreview_Whitespace(t *testing.T) {
got := ResponsePreview([]byte(" \n\t "), 128)
if got != "<empty>" {
t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>")
}
}
// --- AsInt tests ---
func TestAsInt(t *testing.T) {
tests := []struct {
name string
val any
want int
ok bool
}{
{"int", 42, 42, true},
{"int64", int64(99), 99, true},
{"float64", float64(512), 512, true},
{"float32", float32(256), 256, true},
{"string", "nope", 0, false},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := AsInt(tt.val)
if ok != tt.ok || got != tt.want {
t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok)
}
})
}
}
// --- AsFloat tests ---
func TestAsFloat(t *testing.T) {
tests := []struct {
name string
val any
want float64
ok bool
}{
{"float64", float64(0.7), 0.7, true},
{"float32", float32(0.5), float64(float32(0.5)), true},
{"int", 1, 1.0, true},
{"int64", int64(100), 100.0, true},
{"string", "nope", 0, false},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := AsFloat(tt.val)
if ok != tt.ok || got != tt.want {
t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok)
}
})
}
}
// --- WrapHTMLResponseError tests ---
func TestWrapHTMLResponseError(t *testing.T) {
err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", "https://api.example.com")
if err == nil {
t.Fatal("expected error")
}
msg := err.Error()
if !strings.Contains(msg, "502") {
t.Errorf("expected status code in error, got %v", msg)
}
if !strings.Contains(msg, "https://api.example.com") {
t.Errorf("expected api base in error, got %v", msg)
}
if !strings.Contains(msg, "HTML instead of JSON") {
t.Errorf("expected HTML mention in error, got %v", msg)
}
}
// --- HandleErrorResponse with empty body ---
func TestHandleErrorResponse_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
// empty body
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
err = HandleErrorResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("expected status code, got %v", err)
}
}
// --- ReadAndParseResponse with invalid JSON ---
func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("not valid json"))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
_, err = ReadAndParseResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
}
// --- ParseResponse with thought_signature (Google/Gemini) ---
func TestParseResponse_WithThoughtSignature(t *testing.T) {
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].ThoughtSignature != "sig123" {
t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123")
}
if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil {
t.Fatal("ExtraContent.Google is nil")
}
if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" {
t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123")
}
}
+19
@@ -11,6 +11,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
"github.com/sipeed/picoclaw/pkg/providers/azure"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
@@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
case "azure", "azure-openai":
// Azure OpenAI uses deployment-based URLs, api-key header auth,
// and always sends max_completion_tokens.
if cfg.APIKey == "" {
return nil, "", fmt.Errorf("api_key is required for azure protocol")
}
if cfg.APIBase == "" {
return nil, "", fmt.Errorf(
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
)
}
return azure.NewProviderWithTimeout(
cfg.APIKey,
cfg.APIBase,
cfg.Proxy,
cfg.RequestTimeout,
), modelID, nil
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
+72
@@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) {
wantProtocol: "nvidia",
wantModelID: "meta/llama-3.1-8b",
},
{
name: "azure with prefix",
model: "azure/my-gpt5-deployment",
wantProtocol: "azure",
wantModelID: "my-gpt5-deployment",
},
}
for _, tt := range tests {
@@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
}
}
func TestCreateProviderFromConfig_Azure(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIKey: "test-azure-key",
APIBase: "https://my-resource.openai.azure.com",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "my-gpt5-deployment" {
t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment")
}
}
func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt4",
Model: "azure-openai/my-deployment",
APIKey: "test-azure-key",
APIBase: "https://my-resource.openai.azure.com",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "my-deployment" {
t.Errorf("modelID = %q, want %q", modelID, "my-deployment")
}
}
func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIBase: "https://my-resource.openai.azure.com",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
}
}
func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIKey: "test-azure-key",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for missing API base")
}
}
+8 -319
@@ -1,18 +1,16 @@
package openai_compat
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -38,7 +36,7 @@ type Provider struct {
type Option func(*Provider)
const defaultRequestTimeout = 120 * time.Second
const defaultRequestTimeout = common.DefaultRequestTimeout
func WithMaxTokensField(maxTokensField string) Option {
return func(p *Provider) {
@@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
client := &http.Client{
Timeout: defaultRequestTimeout,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
httpClient: common.NewHTTPClient(proxy),
}
for _, opt := range opts {
@@ -117,7 +100,7 @@ func (p *Provider) Chat(
requestBody := map[string]any{
"model": model,
"messages": serializeMessages(messages),
"messages": common.SerializeMessages(messages),
}
if len(tools) > 0 {
@@ -125,7 +108,7 @@ func (p *Provider) Chat(
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := asInt(options["max_tokens"]); ok {
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
// Use the configured maxTokensField if specified; otherwise fall back to model-based detection.
fieldName := p.maxTokensField
if fieldName == "" {
@@ -141,7 +124,7 @@ func (p *Provider) Chat(
requestBody[fieldName] = maxTokens
}
if temperature, ok := asFloat(options["temperature"]); ok {
if temperature, ok := common.AsFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
@@ -185,275 +168,11 @@ func (p *Provider) Chat(
}
defer resp.Body.Close()
contentType := resp.Header.Get("Content-Type")
// Non-200: read a prefix to tell an HTML error page apart from a JSON error body.
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
if readErr != nil {
return nil, fmt.Errorf("failed to read response: %w", readErr)
}
if looksLikeHTML(body, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
}
return nil, fmt.Errorf(
"API request failed:\n Status: %d\n Body: %s",
resp.StatusCode,
responsePreview(body, 128),
)
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
// Peek without consuming so the full stream reaches the JSON decoder.
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if looksLikeHTML(prefix, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
}
out, err := parseResponse(reader)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return out, nil
}
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
respPreview := responsePreview(body, 128)
return fmt.Errorf(
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
apiBase,
contentType,
statusCode,
respPreview,
)
}
func looksLikeHTML(body []byte, contentType string) bool {
contentType = strings.ToLower(strings.TrimSpace(contentType))
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
return true
}
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
bytes.HasPrefix(prefix, []byte("<html")) ||
bytes.HasPrefix(prefix, []byte("<head")) ||
bytes.HasPrefix(prefix, []byte("<body"))
}
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
i := 0
for i < len(body) {
switch body[i] {
case ' ', '\t', '\n', '\r', '\f', '\v':
i++
default:
end := i + maxLen
if end > len(body) {
end = len(body)
}
return body[i:end]
}
}
return nil
}
func responsePreview(body []byte, maxLen int) string {
trimmed := bytes.TrimSpace(body)
if len(trimmed) == 0 {
return "<empty>"
}
if len(trimmed) <= maxLen {
return string(trimmed)
}
return string(trimmed[:maxLen]) + "..."
}
func parseResponse(body io.Reader) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
Reasoning string `json:"reasoning"`
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
} `json:"function"`
ExtraContent *struct {
Google *struct {
ThoughtSignature string `json:"thought_signature"`
} `json:"google"`
} `json:"extra_content"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]any)
name := ""
// Extract thought_signature from Gemini/Google-specific extra content
thoughtSignature := ""
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
}
if tc.Function != nil {
name = tc.Function.Name
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
}
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
toolCall := ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
ThoughtSignature: thoughtSignature,
}
if thoughtSignature != "" {
toolCall.ExtraContent = &ExtraContent{
Google: &GoogleExtra{
ThoughtSignature: thoughtSignature,
},
}
}
toolCalls = append(toolCalls, toolCall)
}
return &LLMResponse{
Content: choice.Message.Content,
ReasoningContent: choice.Message.ReasoningContent,
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
raw = bytes.TrimSpace(raw)
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
return arguments
}
var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
arguments["raw"] = string(raw)
return arguments
}
switch v := decoded.(type) {
case string:
if strings.TrimSpace(v) == "" {
return arguments
}
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = v
}
return arguments
case map[string]any:
return v
default:
log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
arguments["raw"] = string(raw)
return arguments
}
}
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
type openaiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// serializeMessages converts internal Message structs to the OpenAI wire format.
// - Strips SystemParts (unknown to third-party endpoints)
// - Converts messages with Media to multipart content format (text + image_url parts)
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
func serializeMessages(messages []Message) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
if len(m.Media) == 0 {
out = append(out, openaiMessage{
Role: m.Role,
Content: m.Content,
ReasoningContent: m.ReasoningContent,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
})
continue
}
// Multipart content format for messages with media
parts := make([]map[string]any, 0, 1+len(m.Media))
if m.Content != "" {
parts = append(parts, map[string]any{
"type": "text",
"text": m.Content,
})
}
for _, mediaURL := range m.Media {
if strings.HasPrefix(mediaURL, "data:image/") {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
}
msg := map[string]any{
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
out = append(out, msg)
}
return out
return common.ReadAndParseResponse(resp, p.apiBase)
}
func normalizeModel(model, apiBase string) string {
@@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string {
}
}
func asInt(v any) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
func asFloat(v any) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API and Azure OpenAI support this. All other OpenAI-compatible providers
+5 -4
@@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) {
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, err := json.Marshal(result)
if err != nil {
@@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
},
},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
raw := string(data)