chore(provider): use openai responses api for azure openai endpoints (#2110)

Migrate Azure OpenAI provider from legacy Chat Completions API to the OpenAI Responses API.

- Switch API endpoint from `/openai/deployments/{deployment}/chat/completions` to `/openai/v1/responses`
- Change auth header from `Api-Key` to `Authorization: Bearer`
- Use `responses.ResponseNewParams` SDK types for request construction
- Extract shared Responses API utilities into `openai_responses_common` package
- Deduplicate 178 lines from codex_provider.go by reusing shared package
- Add 593 lines of comprehensive test coverage for the shared package

Closes #2111

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Kunal Karmakar
2026-03-28 18:20:24 +05:30
committed by GitHub
parent 026a1339c7
commit 1809d04905
6 changed files with 1126 additions and 260 deletions
+34 -24
View File
@@ -10,7 +10,11 @@ import (
"strings"
"time"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/providers/common"
orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -21,14 +25,12 @@ type (
)
const (
// azureAPIVersion is the Azure OpenAI API version used for all requests.
azureAPIVersion = "2024-10-21"
defaultRequestTimeout = common.DefaultRequestTimeout
)
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
// It handles Azure-specific authentication (api-key header), URL construction
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
// It handles Azure-specific authentication (Bearer token), URL construction
// (Responses API), and request/response formatting.
type Provider struct {
apiKey string
apiBase string
@@ -72,8 +74,8 @@ func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds
)
}
// Chat sends a chat completion request to the Azure OpenAI endpoint.
// The model parameter is used as the Azure deployment name in the URL.
// Chat sends a request to the Azure OpenAI Responses API endpoint.
// The model parameter is passed in the request body.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
@@ -85,34 +87,43 @@ func (p *Provider) Chat(
return nil, fmt.Errorf("Azure API base not configured")
}
// model is the deployment name for Azure OpenAI
deployment := model
// Build Azure-specific URL safely using url.JoinPath and query encoding
// to prevent path traversal or query injection via deployment names.
base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions")
requestURL, err := url.JoinPath(p.apiBase, "openai/v1/responses")
if err != nil {
return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
}
requestURL := base + "?api-version=" + azureAPIVersion
// Build request body — no "model" field (Azure infers from deployment URL)
requestBody := map[string]any{
"messages": common.SerializeMessages(messages),
input, instructions := orc.TranslateMessages(messages)
requestBody := responses.ResponseNewParams{
Model: model,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: input,
},
Store: openai.Opt(false),
}
if instructions != "" {
requestBody.Instructions = openai.Opt(instructions)
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
enableWebSearch, _ := options["native_search"].(bool)
requestBody.Tools = orc.TranslateTools(tools, enableWebSearch)
requestBody.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{
OfToolChoiceMode: openai.Opt(responses.ToolChoiceOptionsAuto),
}
}
// Azure OpenAI always uses max_completion_tokens
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
requestBody["max_completion_tokens"] = maxTokens
requestBody.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temperature, ok := common.AsFloat(options["temperature"]); ok {
requestBody["temperature"] = temperature
requestBody.Temperature = openai.Opt(temperature)
}
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
requestBody.PromptCacheKey = openai.Opt(cacheKey)
}
jsonData, err := json.Marshal(requestBody)
@@ -125,10 +136,9 @@ func (p *Provider) Chat(
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Azure uses api-key header instead of Authorization: Bearer
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Api-Key", p.apiKey)
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
@@ -141,7 +151,7 @@ func (p *Provider) Chat(
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
return common.ReadAndParseResponse(resp, p.apiBase)
return orc.ParseResponseBody(resp.Body)
}
// GetDefaultModel returns an empty string as Azure deployments are user-configured.
+198 -50
View File
@@ -6,17 +6,31 @@ import (
"net/http/httptest"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
// writeValidResponse writes a minimal valid Responses API response.
func writeValidResponse(w http.ResponseWriter) {
resp := map[string]any{
"choices": []map[string]any{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
"type": "message",
"content": []map[string]any{
{"type": "output_text", "text": "ok"},
},
},
},
"usage": map[string]any{
"input_tokens": 5,
"output_tokens": 2,
"total_tokens": 7,
"input_tokens_details": map[string]any{"cached_tokens": 0},
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
@@ -24,11 +38,9 @@ func writeValidResponse(w http.ResponseWriter) {
func TestProviderChat_AzureURLConstruction(t *testing.T) {
var capturedPath string
var capturedAPIVersion string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
capturedAPIVersion = r.URL.Query().Get("api-version")
writeValidResponse(w)
}))
defer server.Close()
@@ -39,22 +51,19 @@ func TestProviderChat_AzureURLConstruction(t *testing.T) {
t.Fatalf("Chat() error = %v", err)
}
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
wantPath := "/openai/v1/responses"
if capturedPath != wantPath {
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
}
if capturedAPIVersion != azureAPIVersion {
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
}
}
func TestProviderChat_AzureAuthHeader(t *testing.T) {
var capturedAPIKey string
var capturedAuth string
var capturedAPIKey string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAPIKey = r.Header.Get("Api-Key")
capturedAuth = r.Header.Get("Authorization")
capturedAPIKey = r.Header.Get("Api-Key")
writeValidResponse(w)
}))
defer server.Close()
@@ -65,15 +74,15 @@ func TestProviderChat_AzureAuthHeader(t *testing.T) {
t.Fatalf("Chat() error = %v", err)
}
if capturedAPIKey != "test-azure-key" {
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
if capturedAuth != "Bearer test-azure-key" {
t.Errorf("Authorization header = %q, want %q", capturedAuth, "Bearer test-azure-key")
}
if capturedAuth != "" {
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
if capturedAPIKey != "" {
t.Errorf("Api-Key header should be empty, got %q", capturedAPIKey)
}
}
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
func TestProviderChat_AzureRequestBodyContainsModel(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -83,17 +92,17 @@ func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
defer server.Close()
p := NewProvider("test-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, exists := requestBody["model"]; exists {
t.Error("request body should not contain 'model' field for Azure OpenAI")
if requestBody["model"] != "my-deployment" {
t.Errorf("model = %v, want %q", requestBody["model"], "my-deployment")
}
}
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
func TestProviderChat_AzureUsesMaxOutputTokens(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -114,12 +123,35 @@ func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
t.Fatalf("Chat() error = %v", err)
}
if _, exists := requestBody["max_completion_tokens"]; !exists {
t.Error("request body should contain 'max_completion_tokens'")
if requestBody["max_output_tokens"] == nil {
t.Error("request body should contain 'max_output_tokens'")
}
if _, exists := requestBody["max_tokens"]; exists {
t.Error("request body should not contain 'max_tokens'")
}
if _, exists := requestBody["max_completion_tokens"]; exists {
t.Error("request body should not contain 'max_completion_tokens'")
}
}
// TestProviderChat_AzureStoreIsFalse checks that the request body sent by Chat
// carries "store": false.
func TestProviderChat_AzureStoreIsFalse(t *testing.T) {
	var sent map[string]any

	capture := func(w http.ResponseWriter, r *http.Request) {
		json.NewDecoder(r.Body).Decode(&sent)
		writeValidResponse(w)
	}
	srv := httptest.NewServer(http.HandlerFunc(capture))
	defer srv.Close()

	provider := NewProvider("test-key", srv.URL, "")
	history := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), history, nil, "deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if got := sent["store"]; got != false {
		t.Errorf("store = %v, want false", got)
	}
}
func TestProviderChat_AzureHTTPError(t *testing.T) {
@@ -135,27 +167,66 @@ func TestProviderChat_AzureHTTPError(t *testing.T) {
}
}
// TestProviderChat_AzureParseTextOutput serves Chat a completed Responses API
// payload holding a single output_text part and checks the parsed content,
// finish reason, and total token usage.
func TestProviderChat_AzureParseTextOutput(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := map[string]any{
			"id":     "resp_1",
			"object": "response",
			"status": "completed",
			"output": []map[string]any{
				{
					"type": "message",
					"content": []map[string]any{
						{"type": "output_text", "text": "Hello there!"},
					},
				},
			},
			"usage": map[string]any{
				"input_tokens": 10, "output_tokens": 5, "total_tokens": 15,
				"input_tokens_details":  map[string]any{"cached_tokens": 0},
				"output_tokens_details": map[string]any{"reasoning_tokens": 0},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if out.Content != "Hello there!" {
		t.Errorf("Content = %q, want %q", out.Content, "Hello there!")
	}
	if out.FinishReason != "stop" {
		t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
	}
	// Usage is a pointer that the shared parser leaves nil when the reported
	// token total is not positive; fail cleanly here rather than panicking on
	// the dereference below if usage parsing ever regresses.
	if out.Usage == nil {
		t.Fatal("Usage = nil, want non-nil")
	}
	if out.Usage.TotalTokens != 15 {
		t.Errorf("TotalTokens = %d, want 15", out.Usage.TotalTokens)
	}
}
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
"choices": []map[string]any{
"id": "resp_2",
"object": "response",
"status": "completed",
"output": []map[string]any{
{
"message": map[string]any{
"content": "",
"tool_calls": []map[string]any{
{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "get_weather",
"arguments": `{"city":"Seattle"}`,
},
},
},
},
"finish_reason": "tool_calls",
"type": "function_call",
"call_id": "call_1",
"name": "get_weather",
"arguments": `{"city":"Seattle"}`,
},
},
"usage": map[string]any{
"input_tokens": 10, "output_tokens": 8, "total_tokens": 18,
"input_tokens_details": map[string]any{"cached_tokens": 0},
"output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
@@ -167,13 +238,15 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "tool_calls")
}
}
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
@@ -205,28 +278,103 @@ func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
}
}
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
var capturedPath string
func TestProviderChat_AzureNativeWebSearchInjection(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
if capturedPath == "" {
capturedPath = r.URL.Path
}
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
tools := []ToolDefinition{
{
Type: "function",
Function: protocoltypes.ToolFunctionDefinition{
Name: "web_search",
Description: "local web search",
Parameters: map[string]any{"type": "object"},
},
},
{
Type: "function",
Function: protocoltypes.ToolFunctionDefinition{
Name: "read_file",
Description: "read a file",
Parameters: map[string]any{"type": "object"},
},
},
}
p := NewProvider("test-key", server.URL, "")
// Deployment name with characters that could cause path injection
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
// With native_search=true: user-defined web_search should be replaced by built-in
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment",
map[string]any{"native_search": true})
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
// The slash and special chars in the deployment name must be escaped, not treated as path separators
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
t.Fatal("deployment name was interpolated without escaping — path injection possible")
toolsAny, ok := requestBody["tools"].([]any)
if !ok {
t.Fatal("request body should contain 'tools' array")
}
if len(toolsAny) != 2 {
t.Fatalf("len(tools) = %d, want 2 (read_file + web_search builtin)", len(toolsAny))
}
// First tool should be read_file (user-defined web_search was skipped)
firstTool, _ := toolsAny[0].(map[string]any)
if firstTool["name"] != "read_file" {
t.Errorf("first tool name = %v, want %q", firstTool["name"], "read_file")
}
// Second tool should be built-in web_search
secondTool, _ := toolsAny[1].(map[string]any)
if secondTool["type"] != "web_search" {
t.Errorf("second tool type = %v, want %q", secondTool["type"], "web_search")
}
}
// TestProviderChat_AzureNoNativeWebSearch checks that, when the options do not
// request native search, a user-defined "web_search" function tool is sent
// through unchanged as the only tool in the request body.
func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) {
	var sent map[string]any

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewDecoder(r.Body).Decode(&sent)
		writeValidResponse(w)
	}))
	defer srv.Close()

	localSearch := ToolDefinition{
		Type: "function",
		Function: protocoltypes.ToolFunctionDefinition{
			Name:        "web_search",
			Description: "local web search",
			Parameters:  map[string]any{"type": "object"},
		},
	}

	provider := NewProvider("test-key", srv.URL, "")
	history := []Message{{Role: "user", Content: "hi"}}

	// No options are passed, so native_search is not set.
	if _, err := provider.Chat(t.Context(), history, []ToolDefinition{localSearch}, "deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	rawTools, ok := sent["tools"].([]any)
	if !ok {
		t.Fatal("request body should contain 'tools' array")
	}
	if n := len(rawTools); n != 1 {
		t.Fatalf("len(tools) = %d, want 1", n)
	}

	// The single entry must still be the function tool, not the built-in one.
	first, _ := rawTools[0].(map[string]any)
	if got := first["type"]; got != "function" {
		t.Errorf("tool type = %v, want %q", got, "function")
	}
}
+6 -184
View File
@@ -2,7 +2,6 @@ package providers
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
@@ -13,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/logger"
orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common"
)
const (
@@ -96,7 +96,7 @@ func (p *CodexProvider) Chat(
}
// Respect tools.web.prefer_native: only inject native search when the agent
// loop requested it (options["native_search"]), so prefer_native: false
// loop passes options["native_search"]=true, so prefer_native=false means no injection.
useNativeSearch := p.enableWebSearch && (options["native_search"] == true)
params := buildCodexParams(messages, tools, resolvedModel, options, useNativeSearch)
@@ -153,7 +153,7 @@ func (p *CodexProvider) Chat(
return nil, fmt.Errorf("codex API call: stream ended without completed response")
}
return parseCodexResponse(resp), nil
return orc.ParseResponseFromStruct(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
@@ -209,89 +209,14 @@ func resolveCodexModel(model string) (string, string) {
func buildCodexParams(
messages []Message, tools []ToolDefinition, model string, options map[string]any, enableWebSearch bool,
) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
for _, msg := range messages {
switch msg.Role {
case "system":
// Use the full concatenated system prompt (static + dynamic + summary)
// as instructions. This keeps behavior consistent with Anthropic and
// OpenAI-compat adapters where the complete system context lives in
// one place. Prefix caching is handled by prompt_cache_key below,
// not by splitting content across instructions vs input messages.
instructions = msg.Content
case "user":
if msg.ToolCallID != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
OfString: openai.Opt(msg.Content),
},
},
})
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleUser,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "assistant":
if len(msg.ToolCalls) > 0 {
if msg.Content != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
for _, tc := range msg.ToolCalls {
name, args, ok := resolveCodexToolCall(tc)
if !ok {
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]any{
"call_id": tc.ID,
})
continue
}
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
CallID: tc.ID,
Name: name,
Arguments: args,
},
})
}
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "tool":
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
OfString: openai.Opt(msg.Content),
},
},
})
}
}
inputItems, instructions := orc.TranslateMessages(messages)
params := responses.ResponseNewParams{
Model: model,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Instructions: openai.Opt(instructions),
Store: openai.Opt(false),
Store: openai.Opt(false),
}
if instructions != "" {
@@ -309,115 +234,12 @@ func buildCodexParams(
}
if len(tools) > 0 || enableWebSearch {
params.Tools = translateToolsForCodex(tools, enableWebSearch)
params.Tools = orc.TranslateTools(tools, enableWebSearch)
}
return params
}
// resolveCodexToolCall derives the function name and a JSON arguments string
// from a tool call.
//
// The name comes from tc.Name, falling back to tc.Function.Name. Arguments are
// taken from the tc.Arguments map (marshaled to JSON) when it is non-empty,
// then from the tc.Function.Arguments string, and default to "{}". ok is false
// when no name can be found or the arguments map cannot be marshaled.
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
	fn := tc.Function

	name = tc.Name
	if name == "" && fn != nil {
		name = fn.Name
	}
	if name == "" {
		return "", "", false
	}

	switch {
	case len(tc.Arguments) > 0:
		encoded, err := json.Marshal(tc.Arguments)
		if err != nil {
			return "", "", false
		}
		return name, string(encoded), true
	case fn != nil && fn.Arguments != "":
		return name, fn.Arguments, true
	default:
		return name, "{}", true
	}
}
// translateToolsForCodex converts tool definitions into Responses API tool
// parameters.
//
// Only definitions of type "function" are translated. When enableWebSearch is
// true, a definition named "web_search" (compared case-insensitively) is
// dropped and the built-in web search tool is appended at the end instead.
func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
	size := len(tools)
	if enableWebSearch {
		size++
	}
	translated := make([]responses.ToolUnionParam, 0, size)

	for i := range tools {
		def := &tools[i]
		shadowedByBuiltin := enableWebSearch && strings.EqualFold(def.Function.Name, "web_search")
		if def.Type != "function" || shadowedByBuiltin {
			continue
		}

		fn := &responses.FunctionToolParam{
			Name:       def.Function.Name,
			Parameters: def.Function.Parameters,
			Strict:     openai.Opt(false),
		}
		if desc := def.Function.Description; desc != "" {
			fn.Description = openai.Opt(desc)
		}
		translated = append(translated, responses.ToolUnionParam{OfFunction: fn})
	}

	if enableWebSearch {
		translated = append(translated, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
	}
	return translated
}
// parseCodexResponse flattens a Responses API result into an LLMResponse.
//
// Text from output_text parts of "message" items is concatenated into Content,
// and each "function_call" item becomes a tool call; arguments that fail to
// decode into a map are kept verbatim under the "raw" key. FinishReason is
// "length" when the status is "incomplete", otherwise "tool_calls" when any
// tool call was collected, otherwise "stop". Usage stays nil unless the
// reported total token count is positive.
func parseCodexResponse(resp *responses.Response) *LLMResponse {
	var (
		text  strings.Builder
		calls []ToolCall
	)

	for i := range resp.Output {
		item := &resp.Output[i]
		switch item.Type {
		case "message":
			for _, part := range item.Content {
				if part.Type != "output_text" {
					continue
				}
				text.WriteString(part.Text)
			}
		case "function_call":
			var parsed map[string]any
			if err := json.Unmarshal([]byte(item.Arguments), &parsed); err != nil {
				parsed = map[string]any{"raw": item.Arguments}
			}
			calls = append(calls, ToolCall{
				ID:        item.CallID,
				Name:      item.Name,
				Arguments: parsed,
			})
		}
	}

	result := &LLMResponse{
		Content:      text.String(),
		ToolCalls:    calls,
		FinishReason: "stop",
	}

	// An incomplete status takes precedence over the presence of tool calls.
	switch {
	case resp.Status == "incomplete":
		result.FinishReason = "length"
	case len(calls) > 0:
		result.FinishReason = "tool_calls"
	}

	if resp.Usage.TotalTokens > 0 {
		result.Usage = &UsageInfo{
			PromptTokens:     int(resp.Usage.InputTokens),
			CompletionTokens: int(resp.Usage.OutputTokens),
			TotalTokens:      int(resp.Usage.TotalTokens),
		}
	}
	return result
}
func createCodexTokenSource() func() (string, string, error) {
return func() (string, string, error) {
cred, err := auth.GetCredential("openai")
+4 -2
View File
@@ -10,6 +10,8 @@ import (
"github.com/openai/openai-go/v3"
openaiopt "github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common"
)
func TestBuildCodexParams_BasicMessage(t *testing.T) {
@@ -225,7 +227,7 @@ func TestParseCodexResponse_TextOutput(t *testing.T) {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
result := orc.ParseResponseFromStruct(&resp)
if result.Content != "Hello there!" {
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
}
@@ -266,7 +268,7 @@ func TestParseCodexResponse_FunctionCall(t *testing.T) {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
result := orc.ParseResponseFromStruct(&resp)
if len(result.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
}
@@ -0,0 +1,291 @@
// Package openai_responses_common provides shared utilities for providers
// that use the OpenAI Responses API (e.g., Azure, Codex).
package openai_responses_common
import (
"encoding/json"
"io"
"strings"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// TranslateMessages maps the internal conversation history onto the input
// shape used by the OpenAI Responses API.
//
// A "system" message is never emitted as an input item: its content is
// returned as instructions, and when several system messages are present the
// last one wins. The remaining roles are appended to input in order:
//   - "user": a function-call output when ToolCallID is set, a multipart
//     input message when Media is non-empty, otherwise a plain text message.
//   - "assistant": with tool calls, an optional text message followed by one
//     function-call item per call accepted by ResolveToolCall; without tool
//     calls, a single text message.
//   - "tool": a function-call output keyed by ToolCallID.
//
// Messages carrying any other role are ignored.
func TranslateMessages(messages []protocoltypes.Message) (input responses.ResponseInputParam, instructions string) {
	// plainText wraps text in an easy-input message for the user or the assistant.
	plainText := func(fromAssistant bool, text string) responses.ResponseInputItemUnionParam {
		role := responses.EasyInputMessageRoleUser
		if fromAssistant {
			role = responses.EasyInputMessageRoleAssistant
		}
		return responses.ResponseInputItemUnionParam{
			OfMessage: &responses.EasyInputMessageParam{
				Role:    role,
				Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(text)},
			},
		}
	}
	// callOutput reports the result of a previously issued function call.
	callOutput := func(callID, text string) responses.ResponseInputItemUnionParam {
		return responses.ResponseInputItemUnionParam{
			OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
				CallID: callID,
				Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
					OfString: openai.Opt(text),
				},
			},
		}
	}

	// Allocate up front so an empty history still yields a non-nil slice.
	input = make(responses.ResponseInputParam, 0, len(messages))

	for i := range messages {
		m := &messages[i]
		switch m.Role {
		case "system":
			instructions = m.Content

		case "tool":
			input = append(input, callOutput(m.ToolCallID, m.Content))

		case "user":
			switch {
			case m.ToolCallID != "":
				// A user message tagged with a call ID is treated as a tool result.
				input = append(input, callOutput(m.ToolCallID, m.Content))
			case len(m.Media) > 0:
				input = append(input, responses.ResponseInputItemUnionParam{
					OfInputMessage: &responses.ResponseInputItemMessageParam{
						Role:    "user",
						Content: BuildMultipartContent(m.Content, m.Media),
					},
				})
			default:
				input = append(input, plainText(false, m.Content))
			}

		case "assistant":
			if len(m.ToolCalls) == 0 {
				input = append(input, plainText(true, m.Content))
				continue
			}
			if m.Content != "" {
				input = append(input, plainText(true, m.Content))
			}
			for _, call := range m.ToolCalls {
				name, args, ok := ResolveToolCall(call)
				if !ok {
					// Calls that cannot be resolved are left out of the history.
					continue
				}
				input = append(input, responses.ResponseInputItemUnionParam{
					OfFunctionCall: &responses.ResponseFunctionToolCallParam{
						CallID:    call.ID,
						Name:      name,
						Arguments: args,
					},
				})
			}
		}
	}
	return input, instructions
}
// BuildMultipartContent assembles the content list for a message that carries
// media attachments.
//
// A non-empty text becomes the leading text part. Each media entry is then
// matched by prefix: "data:image/" URLs are forwarded as image parts with
// detail set to auto, and "data:audio/" URLs accepted by ParseDataAudioURL are
// attached as file parts named "audio.<format>". Entries matching neither
// prefix, and audio URLs that fail to parse, are dropped.
func BuildMultipartContent(text string, media []string) responses.ResponseInputMessageContentListParam {
	content := make(responses.ResponseInputMessageContentListParam, 0, 1+len(media))

	if text != "" {
		content = append(content, responses.ResponseInputContentUnionParam{
			OfInputText: &responses.ResponseInputTextParam{Text: text},
		})
	}

	for _, src := range media {
		switch {
		case strings.HasPrefix(src, "data:image/"):
			content = append(content, responses.ResponseInputContentUnionParam{
				OfInputImage: &responses.ResponseInputImageParam{
					ImageURL: openai.Opt(src),
					Detail:   responses.ResponseInputImageDetailAuto,
				},
			})

		case strings.HasPrefix(src, "data:audio/"):
			format, payload, ok := ParseDataAudioURL(src)
			if !ok {
				continue
			}
			content = append(content, responses.ResponseInputContentUnionParam{
				OfInputFile: &responses.ResponseInputFileParam{
					FileData: openai.Opt(payload),
					Filename: openai.Opt("audio." + format),
				},
			})
		}
	}
	return content
}
// ParseDataAudioURL splits a "data:audio/<format>[;params],<data>" URL into
// its format and data portions.
//
// The format is the text between "data:audio/" and the first ";" (or the
// comma when there are no parameters); the data is everything after the first
// comma. Both are trimmed of surrounding whitespace. ok is false, and both
// strings are empty, when the prefix or comma is missing or when either
// portion is empty after trimming.
func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) {
	rest, hasPrefix := strings.CutPrefix(mediaURL, "data:audio/")
	if !hasPrefix {
		return "", "", false
	}

	header, payload, hasComma := strings.Cut(rest, ",")
	if !hasComma {
		return "", "", false
	}

	// Anything after the first ";" (e.g. ";base64") is metadata, not the format.
	subtype, _, _ := strings.Cut(header, ";")
	format = strings.TrimSpace(subtype)
	data = strings.TrimSpace(payload)
	if format == "" || data == "" {
		return "", "", false
	}
	return format, data, true
}
// ResolveToolCall derives the function name and a JSON arguments string from a
// ToolCall.
//
// The name comes from tc.Name, falling back to tc.Function.Name. Arguments are
// taken from the tc.Arguments map (marshaled to JSON) when it is non-empty,
// then from the tc.Function.Arguments string, and default to "{}". ok is false
// when no name can be found or the arguments map cannot be marshaled.
func ResolveToolCall(tc protocoltypes.ToolCall) (name string, arguments string, ok bool) {
	fn := tc.Function

	name = tc.Name
	if name == "" && fn != nil {
		name = fn.Name
	}
	if name == "" {
		return "", "", false
	}

	switch {
	case len(tc.Arguments) > 0:
		encoded, err := json.Marshal(tc.Arguments)
		if err != nil {
			return "", "", false
		}
		return name, string(encoded), true
	case fn != nil && fn.Arguments != "":
		return name, fn.Arguments, true
	default:
		return name, "{}", true
	}
}
// TranslateTools converts internal tool definitions into Responses API tool
// parameters.
//
// Only definitions of type "function" are translated; each is emitted with
// Strict set to false and a description only when one is provided. When
// enableWebSearch is true, a definition named "web_search" (compared
// case-insensitively) is dropped and the built-in web search tool is appended
// at the end instead, so the two never appear together.
func TranslateTools(tools []protocoltypes.ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
	size := len(tools)
	if enableWebSearch {
		size++
	}
	translated := make([]responses.ToolUnionParam, 0, size)

	for i := range tools {
		def := &tools[i]
		shadowedByBuiltin := enableWebSearch && strings.EqualFold(def.Function.Name, "web_search")
		if def.Type != "function" || shadowedByBuiltin {
			continue
		}

		fn := &responses.FunctionToolParam{
			Name:       def.Function.Name,
			Parameters: def.Function.Parameters,
			Strict:     openai.Opt(false),
		}
		if desc := def.Function.Description; desc != "" {
			fn.Description = openai.Opt(desc)
		}
		translated = append(translated, responses.ToolUnionParam{OfFunction: fn})
	}

	if enableWebSearch {
		translated = append(translated, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
	}
	return translated
}
// ParseResponseBody decodes a Responses API JSON document from body and
// converts it into an LLMResponse via parseResponse. A decoding failure is
// returned unchanged together with a nil response.
func ParseResponseBody(body io.Reader) (*protocoltypes.LLMResponse, error) {
	decoded := new(responses.Response)
	if err := json.NewDecoder(body).Decode(decoded); err != nil {
		return nil, err
	}
	return parseResponse(decoded), nil
}
// ParseResponseFromStruct is the exported entry point for callers that already
// hold a decoded responses.Response; it delegates to parseResponse and returns
// the resulting LLMResponse.
func ParseResponseFromStruct(resp *responses.Response) *protocoltypes.LLMResponse {
	converted := parseResponse(resp)
	return converted
}
// parseResponse flattens a decoded Responses API payload into an LLMResponse.
//
// Output items are visited in order: "message" items contribute their
// output_text and refusal parts to Content, "function_call" items become tool
// calls (arguments that fail to decode into a map are kept verbatim under the
// "raw" key), and "reasoning" items contribute their summary text to
// ReasoningContent. FinishReason is "length" when the status is "incomplete",
// otherwise "tool_calls" when any tool call was collected, otherwise "stop".
// Usage stays nil unless the reported total token count is positive.
func parseResponse(apiResp *responses.Response) *protocoltypes.LLMResponse {
	var (
		text      strings.Builder
		reasoning strings.Builder
		calls     []protocoltypes.ToolCall
	)

	for i := range apiResp.Output {
		item := &apiResp.Output[i]
		switch item.Type {
		case "message":
			for _, part := range item.Content {
				switch part.Type {
				case "output_text":
					text.WriteString(part.Text)
				case "refusal":
					text.WriteString(part.Refusal)
				}
			}

		case "function_call":
			var parsed map[string]any
			if err := json.Unmarshal([]byte(item.Arguments), &parsed); err != nil {
				parsed = map[string]any{"raw": item.Arguments}
			}
			calls = append(calls, protocoltypes.ToolCall{
				ID:        item.CallID,
				Name:      item.Name,
				Arguments: parsed,
			})

		case "reasoning":
			for _, summary := range item.Summary {
				reasoning.WriteString(summary.Text)
			}
		}
	}

	result := &protocoltypes.LLMResponse{
		Content:          text.String(),
		ReasoningContent: reasoning.String(),
		ToolCalls:        calls,
		FinishReason:     "stop",
	}

	// An incomplete status takes precedence over the presence of tool calls.
	switch {
	case apiResp.Status == "incomplete":
		result.FinishReason = "length"
	case len(calls) > 0:
		result.FinishReason = "tool_calls"
	}

	if apiResp.Usage.TotalTokens > 0 {
		result.Usage = &protocoltypes.UsageInfo{
			PromptTokens:     int(apiResp.Usage.InputTokens),
			CompletionTokens: int(apiResp.Usage.OutputTokens),
			TotalTokens:      int(apiResp.Usage.TotalTokens),
		}
	}
	return result
}
@@ -0,0 +1,593 @@
package openai_responses_common
import (
"encoding/json"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// --- TranslateMessages tests ---

// TestTranslateMessages_SystemExtractedAsInstructions checks that a system
// message is returned as the instructions string and only the user message is
// emitted as an input item.
func TestTranslateMessages_SystemExtractedAsInstructions(t *testing.T) {
	const systemPrompt = "You are helpful"
	history := []protocoltypes.Message{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: "Hi"},
	}

	items, gotInstructions := TranslateMessages(history)

	if gotInstructions != systemPrompt {
		t.Errorf("instructions = %q, want %q", gotInstructions, systemPrompt)
	}
	if n := len(items); n != 1 {
		t.Fatalf("len(input) = %d, want 1", n)
	}
	if items[0].OfMessage == nil {
		t.Fatal("expected user message item")
	}
}
func TestTranslateMessages_UserTextMessage(t *testing.T) {
msgs := []protocoltypes.Message{
{Role: "user", Content: "Hello"},
}
input, instructions := TranslateMessages(msgs)
if instructions != "" {
t.Errorf("instructions = %q, want empty", instructions)
}
if len(input) != 1 {
t.Fatalf("len(input) = %d, want 1", len(input))
}
if input[0].OfMessage == nil {
t.Fatal("expected EasyInputMessage")
}
if string(input[0].OfMessage.Role) != "user" {
t.Errorf("role = %q, want %q", input[0].OfMessage.Role, "user")
}
}
func TestTranslateMessages_UserWithToolCallID(t *testing.T) {
msgs := []protocoltypes.Message{
{Role: "user", Content: `{"temp":72}`, ToolCallID: "call_1"},
}
input, _ := TranslateMessages(msgs)
if len(input) != 1 {
t.Fatalf("len(input) = %d, want 1", len(input))
}
if input[0].OfFunctionCallOutput == nil {
t.Fatal("expected FunctionCallOutput for user with ToolCallID")
}
if input[0].OfFunctionCallOutput.CallID != "call_1" {
t.Errorf("CallID = %q, want %q", input[0].OfFunctionCallOutput.CallID, "call_1")
}
}
// TestTranslateMessages_UserWithMedia checks that a user message with attached
// media is translated into a multipart input message with the "user" role.
func TestTranslateMessages_UserWithMedia(t *testing.T) {
	withImage := protocoltypes.Message{
		Role:    "user",
		Content: "Describe this",
		Media:   []string{"data:image/png;base64,abc123"},
	}

	items, _ := TranslateMessages([]protocoltypes.Message{withImage})
	if n := len(items); n != 1 {
		t.Fatalf("len(input) = %d, want 1", n)
	}

	multipart := items[0].OfInputMessage
	if multipart == nil {
		t.Fatal("expected InputMessage for multipart content")
	}
	if multipart.Role != "user" {
		t.Errorf("role = %q, want %q", multipart.Role, "user")
	}
}
// TestTranslateMessages_AssistantWithToolCalls checks that an assistant turn
// with text and a tool call, followed by the tool result, expands into
// separate text, function call and function call output items.
func TestTranslateMessages_AssistantWithToolCalls(t *testing.T) {
	assistantTurn := protocoltypes.Message{
		Role:    "assistant",
		Content: "Let me check",
		ToolCalls: []protocoltypes.ToolCall{
			{ID: "call_1", Name: "get_weather", Arguments: map[string]any{"city": "SF"}},
		},
	}

	items, _ := TranslateMessages([]protocoltypes.Message{
		{Role: "user", Content: "Weather?"},
		assistantTurn,
		{Role: "tool", Content: `{"temp":72}`, ToolCallID: "call_1"},
	})

	// Expected layout: [0] user, [1] assistant text, [2] function call, [3] tool output.
	if n := len(items); n != 4 {
		t.Fatalf("len(input) = %d, want 4", n)
	}
	if items[1].OfMessage == nil {
		t.Fatal("expected assistant text message")
	}

	call := items[2].OfFunctionCall
	if call == nil {
		t.Fatal("expected function call")
	}
	if call.Name != "get_weather" {
		t.Errorf("function name = %q, want %q", call.Name, "get_weather")
	}

	if items[3].OfFunctionCallOutput == nil {
		t.Fatal("expected function call output")
	}
}
func TestTranslateMessages_AssistantWithoutToolCalls(t *testing.T) {
msgs := []protocoltypes.Message{
{Role: "assistant", Content: "Sure thing"},
}
input, _ := TranslateMessages(msgs)
if len(input) != 1 {
t.Fatalf("len(input) = %d, want 1", len(input))
}
if input[0].OfMessage == nil {
t.Fatal("expected EasyInputMessage for assistant without tool calls")
}
}
func TestTranslateMessages_ToolMessage(t *testing.T) {
msgs := []protocoltypes.Message{
{Role: "tool", Content: "result data", ToolCallID: "call_99"},
}
input, _ := TranslateMessages(msgs)
if len(input) != 1 {
t.Fatalf("len(input) = %d, want 1", len(input))
}
if input[0].OfFunctionCallOutput == nil {
t.Fatal("expected FunctionCallOutput")
}
if input[0].OfFunctionCallOutput.CallID != "call_99" {
t.Errorf("CallID = %q, want %q", input[0].OfFunctionCallOutput.CallID, "call_99")
}
}
// --- ResolveToolCall tests ---

// TestResolveToolCall_FromNameAndArguments checks that a tool call described
// by its top-level Name and Arguments map resolves to that name and to an
// argument string containing the map's value.
func TestResolveToolCall_FromNameAndArguments(t *testing.T) {
	call := protocoltypes.ToolCall{
		Name:      "get_weather",
		Arguments: map[string]any{"city": "SF"},
	}

	gotName, gotArgs, ok := ResolveToolCall(call)
	if !ok {
		t.Fatal("expected ok=true")
	}
	if gotName != "get_weather" {
		t.Errorf("name = %q, want %q", gotName, "get_weather")
	}
	if found := strings.Contains(gotArgs, "SF"); !found {
		t.Errorf("args = %q, want to contain SF", gotArgs)
	}
}
// TestResolveToolCall_FromFunctionField checks that name and arguments are
// taken from the nested Function field when the top-level fields are unset.
func TestResolveToolCall_FromFunctionField(t *testing.T) {
	const (
		wantName = "read_file"
		wantArgs = `{"path":"README.md"}`
	)
	call := protocoltypes.ToolCall{
		ID:       "call_1",
		Function: &protocoltypes.FunctionCall{Name: wantName, Arguments: wantArgs},
	}

	gotName, gotArgs, ok := ResolveToolCall(call)
	if !ok {
		t.Fatal("expected ok=true")
	}
	if gotName != wantName {
		t.Errorf("name = %q, want %q", gotName, wantName)
	}
	if gotArgs != wantArgs {
		t.Errorf("args = %q, want %q", gotArgs, wantArgs)
	}
}
// TestResolveToolCall_EmptyName checks that a zero-value tool call is reported
// as unresolvable.
func TestResolveToolCall_EmptyName(t *testing.T) {
	var empty protocoltypes.ToolCall

	if _, _, ok := ResolveToolCall(empty); ok {
		t.Error("expected ok=false for empty tool call")
	}
}
func TestResolveToolCall_NoArgsFallsBackToEmptyObject(t *testing.T) {
tc := protocoltypes.ToolCall{Name: "do_something"}
name, args, ok := ResolveToolCall(tc)
if !ok {
t.Fatal("expected ok=true")
}
if name != "do_something" {
t.Errorf("name = %q, want %q", name, "do_something")
}
if args != "{}" {
t.Errorf("args = %q, want %q", args, "{}")
}
}
// --- TranslateTools tests ---

// TestTranslateTools_FunctionTools checks that a function tool definition is
// translated into a single function tool that keeps its name.
func TestTranslateTools_FunctionTools(t *testing.T) {
	weather := protocoltypes.ToolDefinition{
		Type: "function",
		Function: protocoltypes.ToolFunctionDefinition{
			Name:        "get_weather",
			Description: "Get weather",
			Parameters:  map[string]any{"type": "object"},
		},
	}

	translated := TranslateTools([]protocoltypes.ToolDefinition{weather}, false)
	if n := len(translated); n != 1 {
		t.Fatalf("len(result) = %d, want 1", n)
	}

	fn := translated[0].OfFunction
	if fn == nil {
		t.Fatal("expected function tool")
	}
	if fn.Name != "get_weather" {
		t.Errorf("name = %q, want %q", fn.Name, "get_weather")
	}
}
// TestTranslateTools_SkipsNonFunction checks that a definition whose type is
// not "function" produces no translated tools.
func TestTranslateTools_SkipsNonFunction(t *testing.T) {
	unsupported := []protocoltypes.ToolDefinition{{Type: "not_function"}}

	if n := len(TranslateTools(unsupported, false)); n != 0 {
		t.Errorf("len(result) = %d, want 0", n)
	}
}
// TestTranslateTools_WebSearchAppended checks that enabling web search with no
// user tools yields exactly one web search tool.
func TestTranslateTools_WebSearchAppended(t *testing.T) {
	translated := TranslateTools(nil, true)

	if n := len(translated); n != 1 {
		t.Fatalf("len(result) = %d, want 1", n)
	}
	if translated[0].OfWebSearch == nil {
		t.Fatal("expected web_search tool")
	}
}
// TestTranslateTools_WebSearchReplacesUserDefined checks that, with web search
// enabled, a user-defined function named "web_search" is dropped and the web
// search tool is placed after the remaining function tools.
func TestTranslateTools_WebSearchReplacesUserDefined(t *testing.T) {
	// functionTool builds a function definition that has only a name and an
	// object schema.
	functionTool := func(name string) protocoltypes.ToolDefinition {
		return protocoltypes.ToolDefinition{
			Type: "function",
			Function: protocoltypes.ToolFunctionDefinition{
				Name:       name,
				Parameters: map[string]any{"type": "object"},
			},
		}
	}
	defs := []protocoltypes.ToolDefinition{
		functionTool("web_search"),
		functionTool("read_file"),
	}

	translated := TranslateTools(defs, true)
	if n := len(translated); n != 2 {
		t.Fatalf("len(result) = %d, want 2", n)
	}
	if first := translated[0].OfFunction; first == nil || first.Name != "read_file" {
		t.Errorf("first tool should be read_file, got %v", translated[0])
	}
	if translated[1].OfWebSearch == nil {
		t.Error("second tool should be web_search")
	}
}
// TestTranslateTools_DescriptionOmittedWhenEmpty checks that a function tool
// defined without a description is translated with its Description left unset.
func TestTranslateTools_DescriptionOmittedWhenEmpty(t *testing.T) {
	tools := []protocoltypes.ToolDefinition{
		{
			Type: "function",
			Function: protocoltypes.ToolFunctionDefinition{
				Name:       "no_desc",
				Parameters: map[string]any{"type": "object"},
			},
		},
	}

	result := TranslateTools(tools, false)
	if len(result) != 1 {
		t.Fatalf("len(result) = %d, want 1", len(result))
	}

	// Guard the dereference below so a regression is reported as a test
	// failure instead of a nil-pointer panic.
	fn := result[0].OfFunction
	if fn == nil {
		t.Fatal("expected function tool")
	}
	if fn.Description.Valid() {
		t.Error("Description should not be set when empty")
	}
}
// --- ParseResponseBody tests ---

// TestParseResponseBody_TextOutput checks that a completed response with a
// single output_text part yields that text as Content, a "stop" finish reason
// and the reported total token count.
func TestParseResponseBody_TextOutput(t *testing.T) {
	body := strings.NewReader(`{
"id": "resp_123",
"object": "response",
"status": "completed",
"output": [
{
"type": "message",
"content": [{"type": "output_text", "text": "Hello!"}]
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`)

	result, err := ParseResponseBody(body)
	if err != nil {
		t.Fatalf("ParseResponseBody error: %v", err)
	}
	if result.Content != "Hello!" {
		t.Errorf("Content = %q, want %q", result.Content, "Hello!")
	}
	if result.FinishReason != "stop" {
		t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
	}

	// Usage is a pointer that the parser leaves nil when no tokens are
	// reported; fail cleanly instead of panicking on the dereference below.
	if result.Usage == nil {
		t.Fatal("Usage = nil, want non-nil")
	}
	if result.Usage.TotalTokens != 15 {
		t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
	}
}
// TestParseResponseBody_FunctionCall checks that a function_call output item
// is surfaced as a tool call with its name, call ID and decoded arguments, and
// that the finish reason becomes "tool_calls".
func TestParseResponseBody_FunctionCall(t *testing.T) {
	body := strings.NewReader(`{
"id": "resp_456",
"object": "response",
"status": "completed",
"output": [
{
"type": "function_call",
"call_id": "call_abc",
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}"
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 8,
"total_tokens": 18,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`)

	result, err := ParseResponseBody(body)
	if err != nil {
		t.Fatalf("ParseResponseBody error: %v", err)
	}
	if len(result.ToolCalls) != 1 {
		t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
	}

	call := result.ToolCalls[0]
	if call.Name != "get_weather" {
		t.Errorf("Name = %q, want %q", call.Name, "get_weather")
	}
	if call.ID != "call_abc" {
		t.Errorf("ID = %q, want %q", call.ID, "call_abc")
	}
	// The arguments arrive as a JSON-encoded string; make sure they were
	// decoded into the map rather than dropped or stored as a raw fallback.
	if city, _ := call.Arguments["city"].(string); city != "SF" {
		t.Errorf("Arguments[%q] = %v, want %q", "city", call.Arguments["city"], "SF")
	}
	if result.FinishReason != "tool_calls" {
		t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
	}
}
// TestParseResponseBody_Reasoning checks that a reasoning item's summary text
// lands in ReasoningContent while the message text lands in Content.
func TestParseResponseBody_Reasoning(t *testing.T) {
	const payload = `{
"id": "resp_789",
"object": "response",
"status": "completed",
"output": [
{
"type": "reasoning",
"id": "rs_1",
"summary": [{"type": "summary_text", "text": "Thinking about it..."}]
},
{
"type": "message",
"content": [{"type": "output_text", "text": "The answer is 42."}]
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 10}
}
}`

	parsed, err := ParseResponseBody(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponseBody error: %v", err)
	}
	if got := parsed.Content; got != "The answer is 42." {
		t.Errorf("Content = %q, want %q", got, "The answer is 42.")
	}
	if got := parsed.ReasoningContent; got != "Thinking about it..." {
		t.Errorf("ReasoningContent = %q, want %q", got, "Thinking about it...")
	}
}
// TestParseResponseBody_Refusal checks that a refusal content part is
// surfaced through Content.
func TestParseResponseBody_Refusal(t *testing.T) {
	const payload = `{
"id": "resp_ref",
"object": "response",
"status": "completed",
"output": [
{
"type": "message",
"content": [{"type": "refusal", "refusal": "I cannot help with that."}]
}
],
"usage": {
"input_tokens": 5,
"output_tokens": 5,
"total_tokens": 10,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`

	parsed, err := ParseResponseBody(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponseBody error: %v", err)
	}
	if got := parsed.Content; got != "I cannot help with that." {
		t.Errorf("Content = %q, want %q", got, "I cannot help with that.")
	}
}
// TestParseResponseBody_IncompleteStatus checks that a response whose status
// is "incomplete" is reported with the "length" finish reason.
func TestParseResponseBody_IncompleteStatus(t *testing.T) {
	const payload = `{
"id": "resp_inc",
"object": "response",
"status": "incomplete",
"output": [
{
"type": "message",
"content": [{"type": "output_text", "text": "partial"}]
}
],
"usage": {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}}
}`

	parsed, err := ParseResponseBody(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("error: %v", err)
	}
	if got := parsed.FinishReason; got != "length" {
		t.Errorf("FinishReason = %q, want %q", got, "length")
	}
}
// TestParseResponseBody_FailedStatus pins the current handling of a "failed"
// status: it has no dedicated mapping, so the finish reason stays "stop".
func TestParseResponseBody_FailedStatus(t *testing.T) {
	const payload = `{
"id": "resp_fail",
"object": "response",
"status": "failed",
"output": [],
"usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}}
}`

	parsed, err := ParseResponseBody(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("error: %v", err)
	}
	// Only "incomplete" is mapped specially; every other status falls
	// through to the default finish reason.
	if got := parsed.FinishReason; got != "stop" {
		t.Errorf("FinishReason = %q, want %q", got, "stop")
	}
}
// --- ParseDataAudioURL tests ---

// TestParseDataAudioURL_Valid checks that a well-formed audio data URL is
// split into its format and base64 payload.
func TestParseDataAudioURL_Valid(t *testing.T) {
	gotFormat, gotData, ok := ParseDataAudioURL("data:audio/mp3;base64,SGVsbG8=")

	if !ok {
		t.Fatal("expected ok=true")
	}
	if gotFormat != "mp3" {
		t.Errorf("format = %q, want %q", gotFormat, "mp3")
	}
	if gotData != "SGVsbG8=" {
		t.Errorf("data = %q, want %q", gotData, "SGVsbG8=")
	}
}
// TestParseDataAudioURL_NotAudio checks that an image data URL is rejected.
func TestParseDataAudioURL_NotAudio(t *testing.T) {
	const imageURL = "data:image/png;base64,abc"

	if _, _, ok := ParseDataAudioURL(imageURL); ok {
		t.Error("expected ok=false for non-audio URL")
	}
}
// TestParseDataAudioURL_MalformedNoComma checks that an audio data URL with no
// comma separating header and payload is rejected.
func TestParseDataAudioURL_MalformedNoComma(t *testing.T) {
	const noComma = "data:audio/mp3;base64"

	if _, _, ok := ParseDataAudioURL(noComma); ok {
		t.Error("expected ok=false for malformed URL")
	}
}
// TestParseDataAudioURL_EmptyData checks that an audio data URL with nothing
// after the comma is rejected.
func TestParseDataAudioURL_EmptyData(t *testing.T) {
	const emptyPayload = "data:audio/mp3;base64,"

	if _, _, ok := ParseDataAudioURL(emptyPayload); ok {
		t.Error("expected ok=false for empty data")
	}
}
// --- BuildMultipartContent tests ---

// TestBuildMultipartContent_TextOnly checks that text with no media produces a
// single text part.
func TestBuildMultipartContent_TextOnly(t *testing.T) {
	built := BuildMultipartContent("hello", nil)

	if n := len(built); n != 1 {
		t.Fatalf("len(parts) = %d, want 1", n)
	}
	if built[0].OfInputText == nil {
		t.Fatal("expected text part")
	}
}
// TestBuildMultipartContent_TextAndImage checks that text plus an image data
// URL produces a text part followed by an image part.
func TestBuildMultipartContent_TextAndImage(t *testing.T) {
	media := []string{"data:image/png;base64,abc"}

	built := BuildMultipartContent("describe", media)
	if n := len(built); n != 2 {
		t.Fatalf("len(parts) = %d, want 2", n)
	}

	textPart, imagePart := built[0], built[1]
	if textPart.OfInputText == nil {
		t.Error("first part should be text")
	}
	if imagePart.OfInputImage == nil {
		t.Error("second part should be image")
	}
}
// TestBuildMultipartContent_AudioFile checks that an audio data URL with no
// accompanying text produces a single file part.
func TestBuildMultipartContent_AudioFile(t *testing.T) {
	media := []string{"data:audio/wav;base64,AAAA"}

	built := BuildMultipartContent("", media)
	if n := len(built); n != 1 {
		t.Fatalf("len(parts) = %d, want 1", n)
	}
	if built[0].OfInputFile == nil {
		t.Fatal("expected file part for audio")
	}
}
// TestBuildMultipartContent_EmptyTextSkipped checks that empty text adds no
// text part, leaving only the image part.
func TestBuildMultipartContent_EmptyTextSkipped(t *testing.T) {
	media := []string{"data:image/png;base64,abc"}

	built := BuildMultipartContent("", media)
	if n := len(built); n != 1 {
		t.Fatalf("len(parts) = %d, want 1", n)
	}
	if built[0].OfInputImage == nil {
		t.Error("should only have image part")
	}
}
// --- JSON serialization sanity checks ---

// TestTranslateTools_SerializesToJSON checks that translated tools marshal to
// JSON and that the output mentions both the function tool and web search.
func TestTranslateTools_SerializesToJSON(t *testing.T) {
	definition := protocoltypes.ToolDefinition{
		Type: "function",
		Function: protocoltypes.ToolFunctionDefinition{
			Name:        "test_tool",
			Description: "A test",
			Parameters:  map[string]any{"type": "object"},
		},
	}
	translated := TranslateTools([]protocoltypes.ToolDefinition{definition}, true)

	encoded, err := json.Marshal(translated)
	if err != nil {
		t.Fatalf("json.Marshal error: %v", err)
	}

	serialized := string(encoded)
	if found := strings.Contains(serialized, "test_tool"); !found {
		t.Errorf("JSON should contain test_tool, got: %s", serialized)
	}
	if found := strings.Contains(serialized, "web_search"); !found {
		t.Errorf("JSON should contain web_search, got: %s", serialized)
	}
}