From 6fbd7e0a3fb04929f21f8d1b3ebd2261057c144f Mon Sep 17 00:00:00 2001 From: lc6464 <64722907+lc6464@users.noreply.github.com> Date: Sat, 11 Apr 2026 12:02:58 +0800 Subject: [PATCH] fix(gemini): align thoughtSignature and stream tool IDs --- pkg/providers/gemini_provider.go | 24 +++-- pkg/providers/gemini_provider_test.go | 121 ++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 6 deletions(-) diff --git a/pkg/providers/gemini_provider.go b/pkg/providers/gemini_provider.go index 96f8da66d..561387534 100644 --- a/pkg/providers/gemini_provider.go +++ b/pkg/providers/gemini_provider.go @@ -226,7 +226,6 @@ func (p *GeminiProvider) buildRequestBody( } if thoughtSignature != "" { part.ThoughtSignature = thoughtSignature - part.ThoughtSignatureSnake = thoughtSignature } content.Parts = append(content.Parts, part) } @@ -508,12 +507,25 @@ func parseGeminiStreamResponse( } if part.FunctionCall != nil { tc := buildGeminiToolCall(part) - key := tc.ID - if strings.TrimSpace(key) == "" { - fallbackIndex++ - key = fmt.Sprintf("%s#%d", tc.Name, fallbackIndex) - tc.ID = key + if strings.TrimSpace(tc.Name) == "" { + continue } + + key := strings.TrimSpace(part.FunctionCall.ID) + if key == "" { + if len(toolCallOrder) > 0 { + lastKey := toolCallOrder[len(toolCallOrder)-1] + if lastTC, exists := toolCallsByID[lastKey]; exists && lastTC.Name == tc.Name { + key = lastKey + } + } + if key == "" { + fallbackIndex++ + key = fmt.Sprintf("%s#%d", tc.Name, fallbackIndex) + } + } + + tc.ID = key if _, exists := toolCallsByID[key]; !exists { toolCallOrder = append(toolCallOrder, key) } diff --git a/pkg/providers/gemini_provider_test.go b/pkg/providers/gemini_provider_test.go index 3c90cc4e2..a0ab748eb 100644 --- a/pkg/providers/gemini_provider_test.go +++ b/pkg/providers/gemini_provider_test.go @@ -289,6 +289,127 @@ func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) { } } +func TestGeminiProvider_BuildRequestBody_UsesCamelCaseThoughtSignatureOnly(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + + body := provider.buildRequestBody( + []Message{{ + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "search", + Arguments: map[string]any{"q": "hello"}, + Function: &FunctionCall{ + Name: "search", + Arguments: `{"q":"hello"}`, + ThoughtSignature: "sig-1", + }, + }}, + }}, + nil, + "gemini-2.5-flash", + nil, + ) + + raw, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal request body: %v", err) + } + jsonBody := string(raw) + + if !strings.Contains(jsonBody, `"thoughtSignature":"sig-1"`) { + t.Fatalf("request body = %s, expected camelCase thoughtSignature", jsonBody) + } + if strings.Contains(jsonBody, `"thought_signature"`) { + t.Fatalf("request body = %s, unexpected snake_case thought_signature", jsonBody) + } +} + +func TestGeminiProvider_ChatStreamCoalescesToolCallWithoutWireID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("response writer is not flushable") + } + + chunks := []map[string]any{ + { + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{ + "functionCall": map[string]any{ + "name": "search", + "args": map[string]any{"q": "first"}, + }, + }, + }, + }, + }}, + }, + { + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{ + "functionCall": map[string]any{ + "name": "search", + "args": map[string]any{"q": "second"}, + }, + }, + }, + }, + "finishReason": "STOP", + }}, + }, + } + + for _, chunk := range chunks { + raw, err := json.Marshal(chunk) + if err != nil { + t.Fatalf("marshal chunk: %v", err) + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil { + t.Fatalf("write chunk: %v", err) + } + flusher.Flush() + } + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil) + resp, err := provider.ChatStream( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + nil, + ) + if err != nil { + t.Fatalf("ChatStream() error = %v", err) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls)) + } + tc := resp.ToolCalls[0] + if tc.ID != "search#1" { + t.Fatalf("ToolCall ID = %q, want %q", tc.ID, "search#1") + } + if tc.Name != "search" { + t.Fatalf("ToolCall Name = %q, want %q", tc.Name, "search") + } + if argQ, ok := tc.Arguments["q"].(string); !ok || argQ != "second" { + t.Fatalf("ToolCall Arguments = %#v, want q=second", tc.Arguments) + } + if resp.FinishReason != "tool_calls" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } +} + func TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig(t *testing.T) { provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)