diff --git a/pkg/providers/azure/provider.go b/pkg/providers/azure/provider.go index e0ddbbde4..9d29a90cd 100644 --- a/pkg/providers/azure/provider.go +++ b/pkg/providers/azure/provider.go @@ -10,7 +10,11 @@ import ( "strings" "time" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" + "github.com/sipeed/picoclaw/pkg/providers/common" + orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -21,14 +25,12 @@ type ( ) const ( - // azureAPIVersion is the Azure OpenAI API version used for all requests. - azureAPIVersion = "2024-10-21" defaultRequestTimeout = common.DefaultRequestTimeout ) // Provider implements the LLM provider interface for Azure OpenAI endpoints. -// It handles Azure-specific authentication (api-key header), URL construction -// (deployment-based), and request body formatting (max_completion_tokens, no model field). +// It handles Azure-specific authentication (Bearer token), URL construction +// (Responses API), and request/response formatting. type Provider struct { apiKey string apiBase string @@ -72,8 +74,8 @@ func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds ) } -// Chat sends a chat completion request to the Azure OpenAI endpoint. -// The model parameter is used as the Azure deployment name in the URL. +// Chat sends a request to the Azure OpenAI Responses API endpoint. +// The model parameter is passed in the request body. func (p *Provider) Chat( ctx context.Context, messages []Message, @@ -85,34 +87,43 @@ func (p *Provider) Chat( return nil, fmt.Errorf("Azure API base not configured") } - // model is the deployment name for Azure OpenAI - deployment := model - - // Build Azure-specific URL safely using url.JoinPath and query encoding - // to prevent path traversal or query injection via deployment names. - base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions") + requestURL, err := url.JoinPath(p.apiBase, "openai/v1/responses") if err != nil { return nil, fmt.Errorf("failed to build Azure request URL: %w", err) } - requestURL := base + "?api-version=" + azureAPIVersion - // Build request body — no "model" field (Azure infers from deployment URL) - requestBody := map[string]any{ - "messages": common.SerializeMessages(messages), + input, instructions := orc.TranslateMessages(messages) + + requestBody := responses.ResponseNewParams{ + Model: model, + Input: responses.ResponseNewParamsInputUnion{ + OfInputItemList: input, + }, + Store: openai.Opt(false), + } + + if instructions != "" { + requestBody.Instructions = openai.Opt(instructions) } if len(tools) > 0 { - requestBody["tools"] = tools - requestBody["tool_choice"] = "auto" + enableWebSearch, _ := options["native_search"].(bool) + requestBody.Tools = orc.TranslateTools(tools, enableWebSearch) + requestBody.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfToolChoiceMode: openai.Opt(responses.ToolChoiceOptionsAuto), + } } - // Azure OpenAI always uses max_completion_tokens if maxTokens, ok := common.AsInt(options["max_tokens"]); ok { - requestBody["max_completion_tokens"] = maxTokens + requestBody.MaxOutputTokens = openai.Opt(int64(maxTokens)) } if temperature, ok := common.AsFloat(options["temperature"]); ok { - requestBody["temperature"] = temperature + requestBody.Temperature = openai.Opt(temperature) + } + + if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" { + requestBody.PromptCacheKey = openai.Opt(cacheKey) } jsonData, err := json.Marshal(requestBody) @@ -125,10 +136,9 @@ func (p *Provider) Chat( return nil, fmt.Errorf("failed to create request: %w", err) } - // Azure uses api-key header instead of Authorization: Bearer req.Header.Set("Content-Type", "application/json") if p.apiKey != "" { - req.Header.Set("Api-Key", p.apiKey) + req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.httpClient.Do(req) @@ -141,7 +151,7 @@ func (p *Provider) Chat( return nil, common.HandleErrorResponse(resp, p.apiBase) } - return common.ReadAndParseResponse(resp, p.apiBase) + return orc.ParseResponseBody(resp.Body) } // GetDefaultModel returns an empty string as Azure deployments are user-configured. diff --git a/pkg/providers/azure/provider_test.go b/pkg/providers/azure/provider_test.go index 531b81296..e57d68057 100644 --- a/pkg/providers/azure/provider_test.go +++ b/pkg/providers/azure/provider_test.go @@ -6,17 +6,31 @@ import ( "net/http/httptest" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) -// writeValidResponse writes a minimal valid Azure OpenAI chat completion response. +// writeValidResponse writes a minimal valid Responses API response. func writeValidResponse(w http.ResponseWriter) { resp := map[string]any{ - "choices": []map[string]any{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]any{ { - "message": map[string]any{"content": "ok"}, - "finish_reason": "stop", + "type": "message", + "content": []map[string]any{ + {"type": "output_text", "text": "ok"}, + }, }, }, + "usage": map[string]any{ + "input_tokens": 5, + "output_tokens": 2, + "total_tokens": 7, + "input_tokens_details": map[string]any{"cached_tokens": 0}, + "output_tokens_details": map[string]any{"reasoning_tokens": 0}, + }, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) @@ -24,11 +38,9 @@ func writeValidResponse(w http.ResponseWriter) { func TestProviderChat_AzureURLConstruction(t *testing.T) { var capturedPath string - var capturedAPIVersion string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedPath = r.URL.Path - capturedAPIVersion = r.URL.Query().Get("api-version") writeValidResponse(w) })) defer server.Close() @@ -39,22 +51,19 @@ func TestProviderChat_AzureURLConstruction(t *testing.T) { t.Fatalf("Chat() error = %v", err) } - wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions" + wantPath := "/openai/v1/responses" if capturedPath != wantPath { t.Errorf("URL path = %q, want %q", capturedPath, wantPath) } - if capturedAPIVersion != azureAPIVersion { - t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion) - } } func TestProviderChat_AzureAuthHeader(t *testing.T) { - var capturedAPIKey string var capturedAuth string + var capturedAPIKey string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedAPIKey = r.Header.Get("Api-Key") capturedAuth = r.Header.Get("Authorization") + capturedAPIKey = r.Header.Get("Api-Key") writeValidResponse(w) })) defer server.Close() @@ -65,15 +74,15 @@ func TestProviderChat_AzureAuthHeader(t *testing.T) { t.Fatalf("Chat() error = %v", err) } - if capturedAPIKey != "test-azure-key" { - t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key") + if capturedAuth != "Bearer test-azure-key" { + t.Errorf("Authorization header = %q, want %q", capturedAuth, "Bearer test-azure-key") } - if capturedAuth != "" { - t.Errorf("Authorization header should be empty, got %q", capturedAuth) + if capturedAPIKey != "" { + t.Errorf("Api-Key header should be empty, got %q", capturedAPIKey) } } -func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) { +func TestProviderChat_AzureRequestBodyContainsModel(t *testing.T) { var requestBody map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -83,17 +92,17 @@ func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) { defer server.Close() p := NewProvider("test-key", server.URL, "") - _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-deployment", nil) if err != nil { t.Fatalf("Chat() error = %v", err) } - if _, exists := requestBody["model"]; exists { - t.Error("request body should not contain 'model' field for Azure OpenAI") + if requestBody["model"] != "my-deployment" { + t.Errorf("model = %v, want %q", requestBody["model"], "my-deployment") } } -func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) { +func TestProviderChat_AzureUsesMaxOutputTokens(t *testing.T) { var requestBody map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -114,12 +123,35 @@ func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) { t.Fatalf("Chat() error = %v", err) } - if _, exists := requestBody["max_completion_tokens"]; !exists { - t.Error("request body should contain 'max_completion_tokens'") + if requestBody["max_output_tokens"] == nil { + t.Error("request body should contain 'max_output_tokens'") } if _, exists := requestBody["max_tokens"]; exists { t.Error("request body should not contain 'max_tokens'") } + if _, exists := requestBody["max_completion_tokens"]; exists { + t.Error("request body should not contain 'max_completion_tokens'") + } +} + +func TestProviderChat_AzureStoreIsFalse(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&requestBody) + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["store"] != false { + t.Errorf("store = %v, want false", requestBody["store"]) + } } func TestProviderChat_AzureHTTPError(t *testing.T) { @@ -135,27 +167,66 @@ func TestProviderChat_AzureHTTPError(t *testing.T) { } } +func TestProviderChat_AzureParseTextOutput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "id": "resp_1", + "object": "response", + "status": "completed", + "output": []map[string]any{ + { + "type": "message", + "content": []map[string]any{ + {"type": "output_text", "text": "Hello there!"}, + }, + }, + }, + "usage": map[string]any{ + "input_tokens": 10, "output_tokens": 5, "total_tokens": 15, + "input_tokens_details": map[string]any{"cached_tokens": 0}, + "output_tokens_details": map[string]any{"reasoning_tokens": 0}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if out.Content != "Hello there!" { + t.Errorf("Content = %q, want %q", out.Content, "Hello there!") + } + if out.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop") + } + if out.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", out.Usage.TotalTokens) + } +} + func TestProviderChat_AzureParseToolCalls(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { resp := map[string]any{ - "choices": []map[string]any{ + "id": "resp_2", + "object": "response", + "status": "completed", + "output": []map[string]any{ { - "message": map[string]any{ - "content": "", - "tool_calls": []map[string]any{ - { - "id": "call_1", - "type": "function", - "function": map[string]any{ - "name": "get_weather", - "arguments": `{"city":"Seattle"}`, - }, - }, - }, - }, - "finish_reason": "tool_calls", + "type": "function_call", + "call_id": "call_1", + "name": "get_weather", + "arguments": `{"city":"Seattle"}`, }, }, + "usage": map[string]any{ + "input_tokens": 10, "output_tokens": 8, "total_tokens": 18, + "input_tokens_details": map[string]any{"cached_tokens": 0}, + "output_tokens_details": map[string]any{"reasoning_tokens": 0}, + }, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) @@ -167,13 +238,15 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) { if err != nil { t.Fatalf("Chat() error = %v", err) } - if len(out.ToolCalls) != 1 { t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) } if out.ToolCalls[0].Name != "get_weather" { t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") } + if out.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, "tool_calls") + } } func TestProvider_AzureEmptyAPIBase(t *testing.T) { @@ -205,28 +278,103 @@ func TestProvider_AzureNewProviderWithTimeout(t *testing.T) { } } -func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) { - var capturedPath string +func TestProviderChat_AzureNativeWebSearchInjection(t *testing.T) { + var requestBody map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedPath = r.URL.RawPath // use RawPath to see percent-encoding - if capturedPath == "" { - capturedPath = r.URL.Path - } + json.NewDecoder(r.Body).Decode(&requestBody) writeValidResponse(w) })) defer server.Close() + tools := []ToolDefinition{ + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "web_search", + Description: "local web search", + Parameters: map[string]any{"type": "object"}, + }, + }, + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "read_file", + Description: "read a file", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + p := NewProvider("test-key", server.URL, "") - // Deployment name with characters that could cause path injection - _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil) + // With native_search=true: user-defined web_search should be replaced by built-in + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment", + map[string]any{"native_search": true}) if err != nil { t.Fatalf("Chat() error = %v", err) } - // The slash and special chars in the deployment name must be escaped, not treated as path separators - if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" { - t.Fatal("deployment name was interpolated without escaping — path injection possible") + toolsAny, ok := requestBody["tools"].([]any) + if !ok { + t.Fatal("request body should contain 'tools' array") + } + if len(toolsAny) != 2 { + t.Fatalf("len(tools) = %d, want 2 (read_file + web_search builtin)", len(toolsAny)) + } + + // First tool should be read_file (user-defined web_search was skipped) + firstTool, _ := toolsAny[0].(map[string]any) + if firstTool["name"] != "read_file" { + t.Errorf("first tool name = %v, want %q", firstTool["name"], "read_file") + } + + // Second tool should be built-in web_search + secondTool, _ := toolsAny[1].(map[string]any) + if secondTool["type"] != "web_search" { + t.Errorf("second tool type = %v, want %q", secondTool["type"], "web_search") + } +} + +func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&requestBody) + writeValidResponse(w) + })) + defer server.Close() + + tools := []ToolDefinition{ + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "web_search", + Description: "local web search", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + + p := NewProvider("test-key", server.URL, "") + + // Without native_search: user-defined web_search should be kept as-is + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + toolsAny, ok := requestBody["tools"].([]any) + if !ok { + t.Fatal("request body should contain 'tools' array") + } + if len(toolsAny) != 1 { + t.Fatalf("len(tools) = %d, want 1", len(toolsAny)) + } + + // Should be the user-defined function tool, not built-in + tool, _ := toolsAny[0].(map[string]any) + if tool["type"] != "function" { + t.Errorf("tool type = %v, want %q", tool["type"], "function") } } diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index 4a6d61a4b..d968215cc 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -2,7 +2,6 @@ package providers import ( "context" - "encoding/json" "errors" "fmt" "strings" @@ -13,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/logger" + orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common" ) const ( @@ -96,7 +96,7 @@ func (p *CodexProvider) Chat( } // Respect tools.web.prefer_native: only inject native search when the agent - // loop requested it (options["native_search"]), so prefer_native: false + // loop passes options["native_search"]=true, so prefer_native=false means no injection. useNativeSearch := p.enableWebSearch && (options["native_search"] == true) params := buildCodexParams(messages, tools, resolvedModel, options, useNativeSearch) @@ -153,7 +153,7 @@ func (p *CodexProvider) Chat( return nil, fmt.Errorf("codex API call: stream ended without completed response") } - return parseCodexResponse(resp), nil + return orc.ParseResponseFromStruct(resp), nil } func (p *CodexProvider) GetDefaultModel() string { @@ -209,89 +209,14 @@ func resolveCodexModel(model string) (string, string) { func buildCodexParams( messages []Message, tools []ToolDefinition, model string, options map[string]any, enableWebSearch bool, ) responses.ResponseNewParams { - var inputItems responses.ResponseInputParam - var instructions string - - for _, msg := range messages { - switch msg.Role { - case "system": - // Use the full concatenated system prompt (static + dynamic + summary) - // as instructions. This keeps behavior consistent with Anthropic and - // OpenAI-compat adapters where the complete system context lives in - // one place. Prefix caching is handled by prompt_cache_key below, - // not by splitting content across instructions vs input messages. - instructions = msg.Content - case "user": - if msg.ToolCallID != "" { - inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ - OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ - CallID: msg.ToolCallID, - Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ - OfString: openai.Opt(msg.Content), - }, - }, - }) - } else { - inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleUser, - Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, - }, - }) - } - case "assistant": - if len(msg.ToolCalls) > 0 { - if msg.Content != "" { - inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleAssistant, - Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, - }, - }) - } - for _, tc := range msg.ToolCalls { - name, args, ok := resolveCodexToolCall(tc) - if !ok { - logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]any{ - "call_id": tc.ID, - }) - continue - } - inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ - OfFunctionCall: &responses.ResponseFunctionToolCallParam{ - CallID: tc.ID, - Name: name, - Arguments: args, - }, - }) - } - } else { - inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleAssistant, - Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, - }, - }) - } - case "tool": - inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ - OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ - CallID: msg.ToolCallID, - Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ - OfString: openai.Opt(msg.Content), - }, - }, - }) - } - } + inputItems, instructions := orc.TranslateMessages(messages) params := responses.ResponseNewParams{ Model: model, Input: responses.ResponseNewParamsInputUnion{ OfInputItemList: inputItems, }, - Instructions: openai.Opt(instructions), - Store: openai.Opt(false), + Store: openai.Opt(false), } if instructions != "" { @@ -309,115 +234,12 @@ func buildCodexParams( } if len(tools) > 0 || enableWebSearch { - params.Tools = translateToolsForCodex(tools, enableWebSearch) + params.Tools = orc.TranslateTools(tools, enableWebSearch) } return params } -func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) { - name = tc.Name - if name == "" && tc.Function != nil { - name = tc.Function.Name - } - if name == "" { - return "", "", false - } - - if len(tc.Arguments) > 0 { - argsJSON, err := json.Marshal(tc.Arguments) - if err != nil { - return "", "", false - } - return name, string(argsJSON), true - } - - if tc.Function != nil && tc.Function.Arguments != "" { - return name, tc.Function.Arguments, true - } - - return name, "{}", true -} - -func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam { - capHint := len(tools) - if enableWebSearch { - capHint++ - } - result := make([]responses.ToolUnionParam, 0, capHint) - for _, t := range tools { - if t.Type != "function" { - continue - } - if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") { - continue - } - ft := responses.FunctionToolParam{ - Name: t.Function.Name, - Parameters: t.Function.Parameters, - Strict: openai.Opt(false), - } - if t.Function.Description != "" { - ft.Description = openai.Opt(t.Function.Description) - } - result = append(result, responses.ToolUnionParam{OfFunction: &ft}) - } - if enableWebSearch { - result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch)) - } - return result -} - -func parseCodexResponse(resp *responses.Response) *LLMResponse { - var content strings.Builder - var toolCalls []ToolCall - - for _, item := range resp.Output { - switch item.Type { - case "message": - for _, c := range item.Content { - if c.Type == "output_text" { - content.WriteString(c.Text) - } - } - case "function_call": - var args map[string]any - if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { - args = map[string]any{"raw": item.Arguments} - } - toolCalls = append(toolCalls, ToolCall{ - ID: item.CallID, - Name: item.Name, - Arguments: args, - }) - } - } - - finishReason := "stop" - if len(toolCalls) > 0 { - finishReason = "tool_calls" - } - if resp.Status == "incomplete" { - finishReason = "length" - } - - var usage *UsageInfo - if resp.Usage.TotalTokens > 0 { - usage = &UsageInfo{ - PromptTokens: int(resp.Usage.InputTokens), - CompletionTokens: int(resp.Usage.OutputTokens), - TotalTokens: int(resp.Usage.TotalTokens), - } - } - - return &LLMResponse{ - Content: content.String(), - ToolCalls: toolCalls, - FinishReason: finishReason, - Usage: usage, - } -} - func createCodexTokenSource() func() (string, string, error) { return func() (string, string, error) { cred, err := auth.GetCredential("openai") diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 3a0da5e3b..ad5748e0c 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -10,6 +10,8 @@ import ( "github.com/openai/openai-go/v3" openaiopt "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" + + orc "github.com/sipeed/picoclaw/pkg/providers/openai_responses_common" ) func TestBuildCodexParams_BasicMessage(t *testing.T) { @@ -225,7 +227,7 @@ func TestParseCodexResponse_TextOutput(t *testing.T) { t.Fatalf("unmarshal: %v", err) } - result := parseCodexResponse(&resp) + result := orc.ParseResponseFromStruct(&resp) if result.Content != "Hello there!" { t.Errorf("Content = %q, want %q", result.Content, "Hello there!") } @@ -266,7 +268,7 @@ func TestParseCodexResponse_FunctionCall(t *testing.T) { t.Fatalf("unmarshal: %v", err) } - result := parseCodexResponse(&resp) + result := orc.ParseResponseFromStruct(&resp) if len(result.ToolCalls) != 1 { t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls)) } diff --git a/pkg/providers/openai_responses_common/responses_common.go b/pkg/providers/openai_responses_common/responses_common.go new file mode 100644 index 000000000..29133a51e --- /dev/null +++ b/pkg/providers/openai_responses_common/responses_common.go @@ -0,0 +1,291 @@ +// Package openai_responses_common provides shared utilities for providers +// that use the OpenAI Responses API (e.g., Azure, Codex). +package openai_responses_common + +import ( + "encoding/json" + "io" + "strings" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/responses" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// TranslateMessages converts internal Message entries to the OpenAI Responses API +// input format. System messages are extracted as instructions (returned separately), +// user/assistant/tool messages become ResponseInputItemUnionParam entries. +// Supports multipart media (images, audio). +func TranslateMessages(messages []protocoltypes.Message) (input responses.ResponseInputParam, instructions string) { + input = make(responses.ResponseInputParam, 0, len(messages)) + + for _, msg := range messages { + switch msg.Role { + case "system": + instructions = msg.Content + case "user": + if msg.ToolCallID != "" { + input = append(input, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ + OfString: openai.Opt(msg.Content), + }, + }, + }) + } else if len(msg.Media) > 0 { + content := BuildMultipartContent(msg.Content, msg.Media) + input = append(input, responses.ResponseInputItemUnionParam{ + OfInputMessage: &responses.ResponseInputItemMessageParam{ + Role: "user", + Content: content, + }, + }) + } else { + input = append(input, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + if msg.Content != "" { + input = append(input, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + for _, tc := range msg.ToolCalls { + name, args, ok := ResolveToolCall(tc) + if !ok { + continue + } + input = append(input, responses.ResponseInputItemUnionParam{ + OfFunctionCall: &responses.ResponseFunctionToolCallParam{ + CallID: tc.ID, + Name: name, + Arguments: args, + }, + }) + } + } else { + input = append(input, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "tool": + input = append(input, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ + OfString: openai.Opt(msg.Content), + }, + }, + }) + } + } + + return input, instructions +} + +// BuildMultipartContent constructs a ResponseInputMessageContentListParam from +// text content and media URLs (data:image/... and data:audio/... URIs). +func BuildMultipartContent(text string, media []string) responses.ResponseInputMessageContentListParam { + parts := make(responses.ResponseInputMessageContentListParam, 0, 1+len(media)) + + if text != "" { + parts = append(parts, responses.ResponseInputContentUnionParam{ + OfInputText: &responses.ResponseInputTextParam{ + Text: text, + }, + }) + } + + for _, mediaURL := range media { + if strings.HasPrefix(mediaURL, "data:image/") { + parts = append(parts, responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: openai.Opt(mediaURL), + Detail: responses.ResponseInputImageDetailAuto, + }, + }) + } else if strings.HasPrefix(mediaURL, "data:audio/") { + if format, data, ok := ParseDataAudioURL(mediaURL); ok { + parts = append(parts, responses.ResponseInputContentUnionParam{ + OfInputFile: &responses.ResponseInputFileParam{ + FileData: openai.Opt(data), + Filename: openai.Opt("audio." + format), + }, + }) + } + } + } + + return parts +} + +// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL. +func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) { + if !strings.HasPrefix(mediaURL, "data:audio/") { + return "", "", false + } + payload := strings.TrimPrefix(mediaURL, "data:audio/") + meta, data, found := strings.Cut(payload, ",") + if !found { + return "", "", false + } + format, _, _ = strings.Cut(meta, ";") + format = strings.TrimSpace(format) + data = strings.TrimSpace(data) + if format == "" || data == "" { + return "", "", false + } + return format, data, true +} + +// ResolveToolCall extracts the function name and JSON arguments string from a ToolCall. +// Returns ok=false if the tool call has no name or if arguments fail to marshal. +func ResolveToolCall(tc protocoltypes.ToolCall) (name string, arguments string, ok bool) { + name = tc.Name + if name == "" && tc.Function != nil { + name = tc.Function.Name + } + if name == "" { + return "", "", false + } + + if len(tc.Arguments) > 0 { + argsJSON, err := json.Marshal(tc.Arguments) + if err != nil { + return "", "", false + } + return name, string(argsJSON), true + } + + if tc.Function != nil && tc.Function.Arguments != "" { + return name, tc.Function.Arguments, true + } + + return name, "{}", true +} + +// TranslateTools converts internal ToolDefinition entries to the OpenAI Responses API +// tool format. If enableWebSearch is true, a web_search tool is appended and any +// user-defined tool named "web_search" is skipped to avoid duplicates. +func TranslateTools(tools []protocoltypes.ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam { + capHint := len(tools) + if enableWebSearch { + capHint++ + } + result := make([]responses.ToolUnionParam, 0, capHint) + + for _, t := range tools { + if t.Type != "function" { + continue + } + if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") { + continue + } + ft := responses.FunctionToolParam{ + Name: t.Function.Name, + Parameters: t.Function.Parameters, + Strict: openai.Opt(false), + } + if t.Function.Description != "" { + ft.Description = openai.Opt(t.Function.Description) + } + result = append(result, responses.ToolUnionParam{OfFunction: &ft}) + } + + if enableWebSearch { + result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch)) + } + + return result +} + +// ParseResponseBody parses an OpenAI Responses API JSON body into an LLMResponse. +// Handles output item types: "message" (output_text + refusal), "function_call", and "reasoning". +func ParseResponseBody(body io.Reader) (*protocoltypes.LLMResponse, error) { + var apiResp responses.Response + if err := json.NewDecoder(body).Decode(&apiResp); err != nil { + return nil, err + } + + return parseResponse(&apiResp), nil +} + +// ParseResponseFromStruct converts a decoded responses.Response into an LLMResponse. +// Used by providers that receive the Response struct directly (e.g., via streaming SDK). +func ParseResponseFromStruct(resp *responses.Response) *protocoltypes.LLMResponse { + return parseResponse(resp) +} + +// parseResponse is the shared implementation for extracting LLMResponse fields +// from a decoded responses.Response. +func parseResponse(apiResp *responses.Response) *protocoltypes.LLMResponse { + var content strings.Builder + var reasoningContent strings.Builder + var toolCalls []protocoltypes.ToolCall + + for _, item := range apiResp.Output { + switch item.Type { + case "message": + for _, c := range item.Content { + switch c.Type { + case "output_text": + content.WriteString(c.Text) + case "refusal": + content.WriteString(c.Refusal) + } + } + case "function_call": + var args map[string]any + if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { + args = map[string]any{"raw": item.Arguments} + } + toolCalls = append(toolCalls, protocoltypes.ToolCall{ + ID: item.CallID, + Name: item.Name, + Arguments: args, + }) + case "reasoning": + for _, s := range item.Summary { + reasoningContent.WriteString(s.Text) + } + } + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + if apiResp.Status == "incomplete" { + finishReason = "length" + } + + var usage *protocoltypes.UsageInfo + if apiResp.Usage.TotalTokens > 0 { + usage = &protocoltypes.UsageInfo{ + PromptTokens: int(apiResp.Usage.InputTokens), + CompletionTokens: int(apiResp.Usage.OutputTokens), + TotalTokens: int(apiResp.Usage.TotalTokens), + } + } + + return &protocoltypes.LLMResponse{ + Content: content.String(), + ReasoningContent: reasoningContent.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} diff --git a/pkg/providers/openai_responses_common/responses_common_test.go b/pkg/providers/openai_responses_common/responses_common_test.go new file mode 100644 index 000000000..be10e8427 --- /dev/null +++ b/pkg/providers/openai_responses_common/responses_common_test.go @@ -0,0 +1,593 @@ +package openai_responses_common + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// --- TranslateMessages tests --- + +func TestTranslateMessages_SystemExtractedAsInstructions(t *testing.T) { + msgs := []protocoltypes.Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + input, instructions := TranslateMessages(msgs) + if instructions != "You are helpful" { + t.Errorf("instructions = %q, want %q", instructions, "You are helpful") + } + if len(input) != 1 { + t.Fatalf("len(input) = %d, want 1", len(input)) + } + if input[0].OfMessage == nil { + t.Fatal("expected user message item") + } +} + +func TestTranslateMessages_UserTextMessage(t *testing.T) { + msgs := []protocoltypes.Message{ + {Role: "user", Content: "Hello"}, + } + input, instructions := TranslateMessages(msgs) + if instructions != "" { + t.Errorf("instructions = %q, want empty", instructions) + } + if len(input) != 1 { + t.Fatalf("len(input) = %d, want 1", len(input)) + } + if input[0].OfMessage == nil { + t.Fatal("expected EasyInputMessage") + } + if string(input[0].OfMessage.Role) != "user" { + t.Errorf("role = %q, want %q", input[0].OfMessage.Role, "user") + } +} + +func TestTranslateMessages_UserWithToolCallID(t *testing.T) { + msgs := []protocoltypes.Message{ + {Role: "user", Content: `{"temp":72}`, ToolCallID: "call_1"}, + } + input, _ := TranslateMessages(msgs) + if len(input) != 1 { + t.Fatalf("len(input) = %d, want 1", len(input)) + } + if input[0].OfFunctionCallOutput == nil { + t.Fatal("expected FunctionCallOutput for user with ToolCallID") + } + if input[0].OfFunctionCallOutput.CallID != "call_1" { + t.Errorf("CallID = %q, want %q", input[0].OfFunctionCallOutput.CallID, "call_1") + } +} + +func TestTranslateMessages_UserWithMedia(t *testing.T) { + msgs := []protocoltypes.Message{ + {Role: "user", Content: "Describe this", Media: []string{"data:image/png;base64,abc123"}}, + } + input, _ := TranslateMessages(msgs) + if len(input) != 1 { + t.Fatalf("len(input) = %d, want 1", len(input)) + } + if input[0].OfInputMessage == nil { + t.Fatal("expected InputMessage for multipart content") + } + if input[0].OfInputMessage.Role != "user" { + t.Errorf("role = %q, want %q", input[0].OfInputMessage.Role, "user") + } +} + +func TestTranslateMessages_AssistantWithToolCalls(t *testing.T) { + msgs := []protocoltypes.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + Content: "Let me check", + ToolCalls: []protocoltypes.ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: map[string]any{"city": "SF"}}, + }, + }, + {Role: "tool", Content: `{"temp":72}`, ToolCallID: "call_1"}, + } + input, _ := TranslateMessages(msgs) + // user + assistant text + function_call + tool output = 4 items + if len(input) != 4 { + t.Fatalf("len(input) = %d, want 4", len(input)) + } + // item[1] = assistant text + if input[1].OfMessage == nil { + t.Fatal("expected assistant text message") + } + // item[2] = function call + if input[2].OfFunctionCall == nil { + t.Fatal("expected function call") + } + if input[2].OfFunctionCall.Name != "get_weather" { + t.Errorf("function name = %q, want %q", input[2].OfFunctionCall.Name, "get_weather") + } + // item[3] = tool output + if input[3].OfFunctionCallOutput == nil { + t.Fatal("expected function call output") + } +} + +func TestTranslateMessages_AssistantWithoutToolCalls(t *testing.T) { + msgs := []protocoltypes.Message{ + {Role: "assistant", Content: "Sure thing"}, + } + input, _ := TranslateMessages(msgs) + if len(input) != 1 { + t.Fatalf("len(input) = %d, want 1", len(input)) + } + if input[0].OfMessage == nil { + t.Fatal("expected EasyInputMessage for assistant without tool calls") + } +} + +func TestTranslateMessages_ToolMessage(t *testing.T) { + msgs := []protocoltypes.Message{ + {Role: "tool", Content: "result data", ToolCallID: "call_99"}, + } + input, _ := TranslateMessages(msgs) + if len(input) != 1 { + t.Fatalf("len(input) = %d, want 1", len(input)) + } + if input[0].OfFunctionCallOutput == nil { + t.Fatal("expected FunctionCallOutput") + } + if input[0].OfFunctionCallOutput.CallID != "call_99" { + t.Errorf("CallID = %q, want %q", input[0].OfFunctionCallOutput.CallID, "call_99") + } +} + +// --- ResolveToolCall tests --- + +func TestResolveToolCall_FromNameAndArguments(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "get_weather", + Arguments: map[string]any{"city": "SF"}, + } + name, args, ok := ResolveToolCall(tc) + if !ok { + t.Fatal("expected ok=true") + } + if name != "get_weather" { + t.Errorf("name = %q, want %q", name, "get_weather") + } + if !strings.Contains(args, "SF") { + t.Errorf("args = %q, want to contain SF", args) + } +} + +func TestResolveToolCall_FromFunctionField(t *testing.T) { + tc := protocoltypes.ToolCall{ + ID: "call_1", + Function: &protocoltypes.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md"}`, + }, + } + name, args, ok := ResolveToolCall(tc) + if !ok { + t.Fatal("expected ok=true") + } + if name != "read_file" { + t.Errorf("name = %q, want %q", name, "read_file") + } + if args != `{"path":"README.md"}` { + t.Errorf("args = %q, want %q", args, `{"path":"README.md"}`) + } +} + +func TestResolveToolCall_EmptyName(t *testing.T) { + tc := protocoltypes.ToolCall{} + _, _, ok := ResolveToolCall(tc) + if ok { + t.Error("expected ok=false for empty tool call") + } +} + +func TestResolveToolCall_NoArgsFallsBackToEmptyObject(t *testing.T) { + tc := protocoltypes.ToolCall{Name: "do_something"} + name, args, ok := ResolveToolCall(tc) + if !ok { + t.Fatal("expected ok=true") + } + if name != "do_something" { + t.Errorf("name = %q, want %q", name, "do_something") + } + if args != "{}" { + t.Errorf("args = %q, want %q", args, "{}") + } +} + +// --- TranslateTools tests --- + +func TestTranslateTools_FunctionTools(t *testing.T) { + tools := []protocoltypes.ToolDefinition{ + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + result := TranslateTools(tools, false) + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + if result[0].OfFunction == nil { + t.Fatal("expected function tool") + } + if result[0].OfFunction.Name != "get_weather" { + t.Errorf("name = %q, want %q", result[0].OfFunction.Name, "get_weather") + } +} + +func TestTranslateTools_SkipsNonFunction(t *testing.T) { + tools := []protocoltypes.ToolDefinition{ + {Type: "not_function"}, + } + result := TranslateTools(tools, false) + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestTranslateTools_WebSearchAppended(t *testing.T) { + result := TranslateTools(nil, true) + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + if result[0].OfWebSearch == nil { + t.Fatal("expected web_search tool") + } +} + +func TestTranslateTools_WebSearchReplacesUserDefined(t *testing.T) { + tools := []protocoltypes.ToolDefinition{ + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "web_search", + Parameters: map[string]any{"type": "object"}, + }, + }, + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "read_file", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + result := TranslateTools(tools, true) + if len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } + if result[0].OfFunction == nil || result[0].OfFunction.Name != "read_file" { + t.Errorf("first tool should be read_file, got %v", result[0]) + } + if result[1].OfWebSearch == nil { + t.Error("second tool should be web_search") + } +} + +func TestTranslateTools_DescriptionOmittedWhenEmpty(t *testing.T) { + tools := []protocoltypes.ToolDefinition{ + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "no_desc", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + result := TranslateTools(tools, false) + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + if result[0].OfFunction.Description.Valid() { + t.Error("Description should not be set when empty") + } +} + +// --- ParseResponseBody tests --- + +func TestParseResponseBody_TextOutput(t *testing.T) { + body := strings.NewReader(`{ + "id": "resp_123", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": "Hello!"}] + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }`) + + result, err := ParseResponseBody(body) + if err != nil { + t.Fatalf("ParseResponseBody error: %v", err) + } + if result.Content != "Hello!" { + t.Errorf("Content = %q, want %q", result.Content, "Hello!") + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } + if result.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens) + } +} + +func TestParseResponseBody_FunctionCall(t *testing.T) { + body := strings.NewReader(`{ + "id": "resp_456", + "object": "response", + "status": "completed", + "output": [ + { + "type": "function_call", + "call_id": "call_abc", + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}" + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 8, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }`) + + result, err := ParseResponseBody(body) + if err != nil { + t.Fatalf("ParseResponseBody error: %v", err) + } + if len(result.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls)) + } + if result.ToolCalls[0].Name != "get_weather" { + t.Errorf("Name = %q, want %q", result.ToolCalls[0].Name, "get_weather") + } + if result.ToolCalls[0].ID != "call_abc" { + t.Errorf("ID = %q, want %q", result.ToolCalls[0].ID, "call_abc") + } + if result.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls") + } +} + +func TestParseResponseBody_Reasoning(t *testing.T) { + body := strings.NewReader(`{ + "id": "resp_789", + "object": "response", + "status": "completed", + "output": [ + { + "type": "reasoning", + "id": "rs_1", + "summary": [{"type": "summary_text", "text": "Thinking about it..."}] + }, + { + "type": "message", + "content": [{"type": "output_text", "text": "The answer is 42."}] + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 10} + } + }`) + + result, err := ParseResponseBody(body) + if err != nil { + t.Fatalf("ParseResponseBody error: %v", err) + } + if result.Content != "The answer is 42." { + t.Errorf("Content = %q, want %q", result.Content, "The answer is 42.") + } + if result.ReasoningContent != "Thinking about it..." { + t.Errorf("ReasoningContent = %q, want %q", result.ReasoningContent, "Thinking about it...") + } +} + +func TestParseResponseBody_Refusal(t *testing.T) { + body := strings.NewReader(`{ + "id": "resp_ref", + "object": "response", + "status": "completed", + "output": [ + { + "type": "message", + "content": [{"type": "refusal", "refusal": "I cannot help with that."}] + } + ], + "usage": { + "input_tokens": 5, + "output_tokens": 5, + "total_tokens": 10, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }`) + + result, err := ParseResponseBody(body) + if err != nil { + t.Fatalf("ParseResponseBody error: %v", err) + } + if result.Content != "I cannot help with that." { + t.Errorf("Content = %q, want %q", result.Content, "I cannot help with that.") + } +} + +func TestParseResponseBody_IncompleteStatus(t *testing.T) { + body := strings.NewReader(`{ + "id": "resp_inc", + "object": "response", + "status": "incomplete", + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": "partial"}] + } + ], + "usage": {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0}} + }`) + + result, err := ParseResponseBody(body) + if err != nil { + t.Fatalf("error: %v", err) + } + if result.FinishReason != "length" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "length") + } +} + +func TestParseResponseBody_FailedStatus(t *testing.T) { + body := strings.NewReader(`{ + "id": "resp_fail", + "object": "response", + "status": "failed", + "output": [], + "usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0}} + }`) + + result, err := ParseResponseBody(body) + if err != nil { + t.Fatalf("error: %v", err) + } + // failed/canceled statuses are not specially mapped; they fall through to "stop" + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +// --- ParseDataAudioURL tests --- + +func TestParseDataAudioURL_Valid(t *testing.T) { + format, data, ok := ParseDataAudioURL("data:audio/mp3;base64,SGVsbG8=") + if !ok { + t.Fatal("expected ok=true") + } + if format != "mp3" { + t.Errorf("format = %q, want %q", format, "mp3") + } + if data != "SGVsbG8=" { + t.Errorf("data = %q, want %q", data, "SGVsbG8=") + } +} + +func TestParseDataAudioURL_NotAudio(t *testing.T) { + _, _, ok := ParseDataAudioURL("data:image/png;base64,abc") + if ok { + t.Error("expected ok=false for non-audio URL") + } +} + +func TestParseDataAudioURL_MalformedNoComma(t *testing.T) { + _, _, ok := ParseDataAudioURL("data:audio/mp3;base64") + if ok { + t.Error("expected ok=false for malformed URL") + } +} + +func TestParseDataAudioURL_EmptyData(t *testing.T) { + _, _, ok := ParseDataAudioURL("data:audio/mp3;base64,") + if ok { + t.Error("expected ok=false for empty data") + } +} + +// --- BuildMultipartContent tests --- + +func TestBuildMultipartContent_TextOnly(t *testing.T) { + parts := BuildMultipartContent("hello", nil) + if len(parts) != 1 { + t.Fatalf("len(parts) = %d, want 1", len(parts)) + } + if parts[0].OfInputText == nil { + t.Fatal("expected text part") + } +} + +func TestBuildMultipartContent_TextAndImage(t *testing.T) { + parts := BuildMultipartContent("describe", []string{"data:image/png;base64,abc"}) + if len(parts) != 2 { + t.Fatalf("len(parts) = %d, want 2", len(parts)) + } + if parts[0].OfInputText == nil { + t.Error("first part should be text") + } + if parts[1].OfInputImage == nil { + t.Error("second part should be image") + } +} + +func TestBuildMultipartContent_AudioFile(t *testing.T) { + parts := BuildMultipartContent("", []string{"data:audio/wav;base64,AAAA"}) + if len(parts) != 1 { + t.Fatalf("len(parts) = %d, want 1", len(parts)) + } + if parts[0].OfInputFile == nil { + t.Fatal("expected file part for audio") + } +} + +func TestBuildMultipartContent_EmptyTextSkipped(t *testing.T) { + parts := BuildMultipartContent("", []string{"data:image/png;base64,abc"}) + if len(parts) != 1 { + t.Fatalf("len(parts) = %d, want 1", len(parts)) + } + if parts[0].OfInputImage == nil { + t.Error("should only have image part") + } +} + +// --- JSON serialization sanity checks --- + +func TestTranslateTools_SerializesToJSON(t *testing.T) { + tools := []protocoltypes.ToolDefinition{ + { + Type: "function", + Function: protocoltypes.ToolFunctionDefinition{ + Name: "test_tool", + Description: "A test", + Parameters: map[string]any{"type": "object"}, + }, + }, + } + result := TranslateTools(tools, true) + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("json.Marshal error: %v", err) + } + s := string(data) + if !strings.Contains(s, "test_tool") { + t.Errorf("JSON should contain test_tool, got: %s", s) + } + if !strings.Contains(s, "web_search") { + t.Errorf("JSON should contain web_search, got: %s", s) + } +}