// decodeToolCallArguments normalizes the "arguments" payload of a tool call
// into a map.
//
// Two wire shapes are accepted:
//   - a JSON string whose content is itself a JSON object,
//     e.g. "{\"city\":\"SF\"}"
//   - a JSON object, e.g. {"city":"SF"}
//
// raw is the undecoded "arguments" value; name is the tool name and is used
// only to label log messages.
//
// The returned map is never nil. An absent, empty, or null payload (including
// a string whose content is blank or the literal null) yields an empty map.
// When the payload cannot be turned into an object, the problem is logged and
// the result holds the undecodable text under the single key "raw" so callers
// can still inspect it.
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
	raw = bytes.TrimSpace(raw)
	if len(raw) == 0 || string(raw) == "null" {
		return map[string]any{}
	}

	var decoded any
	if err := json.Unmarshal(raw, &decoded); err != nil {
		log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
		return map[string]any{"raw": string(raw)}
	}

	switch v := decoded.(type) {
	case string:
		if strings.TrimSpace(v) == "" {
			return map[string]any{}
		}
		// Decode into a fresh variable: json.Unmarshal sets a map target to nil
		// (without reporting an error) when the input is the JSON literal null.
		var arguments map[string]any
		if err := json.Unmarshal([]byte(v), &arguments); err != nil {
			log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
			return map[string]any{"raw": v}
		}
		if arguments == nil {
			// String-encoded null: treat it like a raw null payload.
			return map[string]any{}
		}
		return arguments
	case map[string]any:
		return v
	default:
		// Arrays, numbers and booleans are valid JSON but not an argument object.
		log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
		return map[string]any{"raw": string(raw)}
	}
}