fix(gemini): harden dedicated provider compatibility

This commit is contained in:
lc6464
2026-04-11 00:50:24 +08:00
parent c8bac699fe
commit 459e78c076
7 changed files with 1342 additions and 5 deletions
+20 -2
View File
@@ -114,7 +114,7 @@ func ResolveAPIBase(cfg *config.ModelConfig) string {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq),
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
// See the switch on protocol in this function for the authoritative list.
// Returns the provider, the model ID (without protocol prefix), and any error.
@@ -218,7 +218,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
}
return provider, modelID, nil
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "gemini", "nvidia", "venice",
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
"qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita",
@@ -242,6 +242,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
), modelID, nil
case "gemini":
if cfg.APIKey() == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for gemini protocol (model: %s)", cfg.Model)
}
apiBase := cfg.APIBase
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewGeminiProvider(
cfg.APIKey(),
apiBase,
cfg.Proxy,
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
cfg.CustomHeaders,
), modelID, nil
case "minimax":
// Minimax requires reasoning_split: true in the request body
if cfg.APIKey() == "" && cfg.APIBase == "" {
+56
View File
@@ -434,6 +434,62 @@ func TestCreateProviderFromConfig_Antigravity(t *testing.T) {
}
}
// TestCreateProviderFromConfig_Gemini checks that a "gemini/"-prefixed model with an
// API key yields a *GeminiProvider and that the protocol prefix is stripped from the
// returned model ID.
func TestCreateProviderFromConfig_Gemini(t *testing.T) {
	modelCfg := &config.ModelConfig{
		ModelName: "test-gemini",
		Model:     "gemini/gemini-2.5-flash",
	}
	modelCfg.SetAPIKey("test-key")

	got, gotModelID, err := CreateProviderFromConfig(modelCfg)
	if err != nil {
		t.Fatalf("CreateProviderFromConfig() error = %v", err)
	}
	if got == nil {
		t.Fatal("CreateProviderFromConfig() returned nil provider")
	}
	if gotModelID != "gemini-2.5-flash" {
		t.Errorf("modelID = %q, want %q", gotModelID, "gemini-2.5-flash")
	}
	if _, isGemini := got.(*GeminiProvider); !isGemini {
		t.Fatalf("expected *GeminiProvider, got %T", got)
	}
}
// TestCreateProviderFromConfig_GeminiMissingAPIKey checks that a gemini model configured
// with neither an API key nor an API base is rejected with an error.
func TestCreateProviderFromConfig_GeminiMissingAPIKey(t *testing.T) {
	modelCfg := &config.ModelConfig{
		ModelName: "test-gemini-no-key",
		Model:     "gemini/gemini-2.5-flash",
	}

	if _, _, err := CreateProviderFromConfig(modelCfg); err == nil {
		t.Fatal("CreateProviderFromConfig() expected error for missing gemini API key")
	}
}
// TestCreateProviderFromConfig_GeminiCustomAPIBaseWithoutKey checks that a custom API
// base alone (no API key) is enough to construct a *GeminiProvider.
func TestCreateProviderFromConfig_GeminiCustomAPIBaseWithoutKey(t *testing.T) {
	modelCfg := &config.ModelConfig{
		ModelName: "test-gemini-custom-base",
		Model:     "gemini/gemini-2.5-flash",
		APIBase:   "https://proxy.example.com/v1beta",
	}

	got, gotModelID, err := CreateProviderFromConfig(modelCfg)
	if err != nil {
		t.Fatalf("CreateProviderFromConfig() error = %v", err)
	}
	if got == nil {
		t.Fatal("CreateProviderFromConfig() returned nil provider")
	}
	if gotModelID != "gemini-2.5-flash" {
		t.Errorf("modelID = %q, want %q", gotModelID, "gemini-2.5-flash")
	}
	if _, isGemini := got.(*GeminiProvider); !isGemini {
		t.Fatalf("expected *GeminiProvider, got %T", got)
	}
}
func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-claude-cli",
+758
View File
@@ -0,0 +1,758 @@
package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
// Defaults used when the caller supplies no API base or no model name.
const (
// geminiDefaultAPIBase is the endpoint root used when the configured API base is blank.
geminiDefaultAPIBase = "https://generativelanguage.googleapis.com/v1beta"
// geminiDefaultModel is the model ID used when the requested model is empty.
geminiDefaultModel = "gemini-2.0-flash"
)
// GeminiProvider talks to a Gemini generateContent-style HTTP API.
// Instances are built by NewGeminiProvider.
type GeminiProvider struct {
// apiKey is sent as the x-goog-api-key header when non-empty.
apiKey string
// apiBase is the endpoint root, stored without trailing slashes.
apiBase string
// httpClient performs non-streaming requests; streaming reuses only its Transport.
httpClient *http.Client
// extraBody holds top-level request-body entries applied after the generated ones.
extraBody map[string]any
// customHeaders are set on every request after the built-in headers.
customHeaders map[string]string
// userAgent is sent as the User-Agent header when non-empty.
userAgent string
}
// NewGeminiProvider builds a GeminiProvider.
//
// A blank apiBase falls back to geminiDefaultAPIBase; surrounding whitespace and
// trailing slashes are removed from the base. The HTTP client comes from
// common.NewHTTPClient(proxy) and receives a whole-request timeout only when
// requestTimeoutSeconds is positive. extraBody and customHeaders are shallow-copied
// so later changes to the caller's maps do not affect the provider.
func NewGeminiProvider(
	apiKey string,
	apiBase string,
	proxy string,
	userAgent string,
	requestTimeoutSeconds int,
	extraBody map[string]any,
	customHeaders map[string]string,
) *GeminiProvider {
	base := strings.TrimSpace(apiBase)
	if base == "" {
		base = geminiDefaultAPIBase
	}
	base = strings.TrimRight(base, "/")

	httpClient := common.NewHTTPClient(proxy)
	if requestTimeoutSeconds > 0 {
		httpClient.Timeout = time.Duration(requestTimeoutSeconds) * time.Second
	}

	provider := &GeminiProvider{
		apiBase:    base,
		httpClient: httpClient,
	}
	provider.apiKey = strings.TrimSpace(apiKey)
	provider.userAgent = strings.TrimSpace(userAgent)
	provider.extraBody = cloneAnyMap(extraBody)
	provider.customHeaders = cloneStringMap(customHeaders)
	return provider
}
// GetDefaultModel returns the model ID used when no model is specified
// (geminiDefaultModel).
func (p *GeminiProvider) GetDefaultModel() string {
return geminiDefaultModel
}
// SupportsThinking always reports true for this provider. Whether a thinkingConfig
// is actually sent is decided per model in buildGeminiThinkingConfig.
func (p *GeminiProvider) SupportsThinking() bool {
return true
}
// Chat sends one non-streaming generateContent request and returns the parsed reply.
//
// The model name is normalized first (see normalizeGeminiModel). Any non-200 status
// is turned into an error via common.HandleErrorResponse.
func (p *GeminiProvider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("API base not configured")
	}

	modelID := normalizeGeminiModel(model)
	payload, err := json.Marshal(p.buildRequestBody(messages, tools, modelID, options))
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	endpoint := p.apiBase + "/models/" + modelID + ":generateContent"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	p.applyHeaders(httpReq)

	httpResp, err := p.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		return nil, common.HandleErrorResponse(httpResp, p.apiBase)
	}

	decoded := new(geminiGenerateContentResponse)
	if err := json.NewDecoder(httpResp.Body).Decode(decoded); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return parseGeminiResponse(decoded), nil
}
// ChatStream sends a streamGenerateContent request (alt=sse) and parses the
// server-sent events into a single response. onChunk, when non-nil, is handed to
// parseGeminiStreamResponse, which calls it with the accumulated visible text.
//
// Any non-200 status is turned into an error via common.HandleErrorResponse.
func (p *GeminiProvider) ChatStream(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
	onChunk func(accumulated string),
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("API base not configured")
	}

	modelID := normalizeGeminiModel(model)
	payload, err := json.Marshal(p.buildRequestBody(messages, tools, modelID, options))
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	endpoint := p.apiBase + "/models/" + modelID + ":streamGenerateContent?alt=sse"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	p.applyHeaders(httpReq)
	httpReq.Header.Set("Accept", "text/event-stream")

	// A whole-request timeout would cut long streams short, so only the transport of
	// the configured client is reused; cancellation is left to ctx.
	untimedClient := &http.Client{Transport: p.httpClient.Transport}
	httpResp, err := untimedClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		return nil, common.HandleErrorResponse(httpResp, p.apiBase)
	}
	return parseGeminiStreamResponse(ctx, httpResp.Body, onChunk)
}
// applyHeaders sets the JSON content type, the optional API key and User-Agent
// headers, and finally every custom header with a non-blank name. Custom headers
// are applied last, so they override the built-in ones on a name clash.
func (p *GeminiProvider) applyHeaders(req *http.Request) {
	headers := req.Header
	headers.Set("Content-Type", "application/json")

	if key := p.apiKey; key != "" {
		headers.Set("x-goog-api-key", key)
	}
	if agent := p.userAgent; agent != "" {
		headers.Set("User-Agent", agent)
	}

	for name, value := range p.customHeaders {
		if strings.TrimSpace(name) != "" {
			headers.Set(name, value)
		}
	}
}
// buildRequestBody converts the provider-neutral conversation into the JSON body
// for generateContent / streamGenerateContent.
//
// The returned map may contain "contents", "systemInstruction", "tools" and
// "generationConfig"; entries from p.extraBody are copied in last and replace any
// generated top-level key with the same name.
func (p *GeminiProvider) buildRequestBody(
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) map[string]any {
contents := make([]geminiContent, 0, len(messages))
// Maps assistant tool-call IDs to function names; it is passed to
// resolveToolResponseName when a later tool result is converted.
toolCallNames := make(map[string]string)
var systemInstruction *geminiContent
for _, msg := range messages {
switch msg.Role {
case "system":
// A later non-blank system message replaces any earlier one.
if strings.TrimSpace(msg.Content) != "" {
systemInstruction = &geminiContent{Parts: []geminiPart{{Text: msg.Content}}}
}
case "user":
// A user message carrying a ToolCallID is sent as a function response.
if msg.ToolCallID != "" {
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
}},
})
continue
}
parts := make([]geminiPart, 0, 1+len(msg.Media))
if strings.TrimSpace(msg.Content) != "" {
parts = append(parts, geminiPart{Text: msg.Content})
}
parts = append(parts, buildInlineMediaParts(msg.Media)...)
// Messages with neither text nor usable media are dropped.
if len(parts) > 0 {
contents = append(contents, geminiContent{Role: "user", Parts: parts})
}
case "assistant":
// Assistant turns are sent with the role "model".
content := geminiContent{Role: "model"}
if strings.TrimSpace(msg.Content) != "" {
content.Parts = append(content.Parts, geminiPart{Text: msg.Content})
}
for _, tc := range msg.ToolCalls {
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
if toolName == "" {
continue
}
if tc.ID != "" {
toolCallNames[tc.ID] = toolName
}
part := geminiPart{
FunctionCall: &geminiFunctionCall{
Name: toolName,
Args: toolArgs,
ID: tc.ID,
},
}
// The signature is written under both the camelCase and snake_case JSON keys.
if thoughtSignature != "" {
part.ThoughtSignature = thoughtSignature
part.ThoughtSignatureSnake = thoughtSignature
}
content.Parts = append(content.Parts, part)
}
if len(content.Parts) > 0 {
contents = append(contents, content)
}
case "tool":
// Tool results are sent with the role "user" as a function response.
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
}},
})
}
}
body := map[string]any{
"contents": contents,
}
if systemInstruction != nil {
body["systemInstruction"] = systemInstruction
}
if len(tools) > 0 {
funcDecls := make([]geminiFunctionDeclaration, 0, len(tools))
for _, t := range tools {
// Only tools of type "function" are declared.
if t.Type != "function" {
continue
}
funcDecls = append(funcDecls, geminiFunctionDeclaration{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: sanitizeSchemaForGemini(t.Function.Parameters),
})
}
if len(funcDecls) > 0 {
body["tools"] = []geminiTool{{FunctionDeclarations: funcDecls}}
}
}
generationConfig := make(map[string]any)
// max_tokens is accepted as int or float64; other types and non-positive values are ignored.
if val, ok := options["max_tokens"]; ok {
if maxTokens, ok := val.(int); ok && maxTokens > 0 {
generationConfig["maxOutputTokens"] = maxTokens
} else if maxTokens, ok := val.(float64); ok && maxTokens > 0 {
generationConfig["maxOutputTokens"] = int(maxTokens)
}
}
// temperature is forwarded only when it is a float64.
if temp, ok := options["temperature"].(float64); ok {
generationConfig["temperature"] = temp
}
if thinkingConfig := buildGeminiThinkingConfig(model, options); len(thinkingConfig) > 0 {
generationConfig["thinkingConfig"] = thinkingConfig
}
// generationConfig is omitted entirely when nothing was set.
if len(generationConfig) > 0 {
body["generationConfig"] = generationConfig
}
// Extra body entries are applied last and overwrite generated keys of the same name.
for k, v := range p.extraBody {
body[k] = v
}
return body
}
// normalizeGeminiModel turns a user-supplied model string into a bare model ID:
// surrounding whitespace and a leading "models/" are removed, a remaining
// "<protocol>/<id>" form is reduced to the ID returned by ExtractProtocol, and an
// empty result falls back to geminiDefaultModel.
func normalizeGeminiModel(model string) string {
	trimmed := strings.TrimPrefix(strings.TrimSpace(model), "models/")

	if strings.Contains(trimmed, "/") {
		if _, id := ExtractProtocol(trimmed); id != "" {
			return id
		}
	}

	if trimmed != "" {
		return trimmed
	}
	return geminiDefaultModel
}
// mapGeminiThinkingLevel maps a thinking level (case-insensitive, whitespace
// ignored) to the value sent as thinkingLevel. "off" collapses to "minimal",
// "xhigh" and "adaptive" collapse to "high", and unknown levels yield "".
func mapGeminiThinkingLevel(level string) string {
	normalized := strings.ToLower(strings.TrimSpace(level))
	switch normalized {
	case "low", "medium":
		return normalized
	case "off", "minimal":
		return "minimal"
	case "adaptive", "xhigh", "high":
		return "high"
	}
	return ""
}
// buildGeminiThinkingConfig returns the thinkingConfig object for model, or nil when
// geminiModelSupportsThinkingConfig reports false.
//
// The config always sets includeThoughts. For models matched by isGemini25Model the
// "thinking_level" option is translated into a thinkingBudget; for other supported
// models it is translated into a thinkingLevel. A missing or unrecognized level
// leaves the respective key unset.
func buildGeminiThinkingConfig(model string, options map[string]any) map[string]any {
	if !geminiModelSupportsThinkingConfig(model) {
		return nil
	}

	level := ""
	if raw, isString := options["thinking_level"].(string); isString {
		level = strings.ToLower(strings.TrimSpace(raw))
	}

	thinking := map[string]any{"includeThoughts": true}
	if isGemini25Model(model) {
		if budget, known := mapGeminiThinkingBudget(level, model); known {
			thinking["thinkingBudget"] = budget
		}
	} else if mapped := mapGeminiThinkingLevel(level); mapped != "" {
		thinking["thinkingLevel"] = mapped
	}
	return thinking
}
// geminiModelSupportsThinkingConfig reports whether a thinkingConfig is sent for
// model: true when the lower-cased name contains "gemini-3" or is matched by
// isGemini25Model.
func geminiModelSupportsThinkingConfig(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	if isGemini25Model(normalized) {
		return true
	}
	return strings.Contains(normalized, "gemini-3")
}
// isGemini25Model reports whether the model name (case-insensitive) contains
// "gemini-2.5" or "gemini-25".
func isGemini25Model(model string) bool {
	name := strings.ToLower(strings.TrimSpace(model))
	for _, marker := range [...]string{"gemini-2.5", "gemini-25"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
// mapGeminiThinkingBudget translates a thinking level (case-insensitive, whitespace
// ignored) into a thinkingBudget value. The boolean is false for an empty or
// unknown level, in which case no budget should be sent.
func mapGeminiThinkingBudget(level string, model string) (int, bool) {
	switch strings.ToLower(strings.TrimSpace(level)) {
	case "adaptive":
		return -1, true
	case "minimal", "off":
		// Gemini 2.5 Pro cannot disable thinking; use the lowest supported budget.
		if strings.Contains(strings.ToLower(model), "pro") {
			return 128, true
		}
		return 0, true
	case "low":
		return 1024, true
	case "medium":
		return 4096, true
	case "high":
		return 8192, true
	case "xhigh":
		return 16384, true
	}
	return 0, false
}
// parseGeminiResponse flattens a generateContent response into an LLMResponse.
//
// Text parts flagged as thoughts are concatenated into ReasoningContent, all other
// text into Content, and every functionCall part becomes a ToolCall. Parts of all
// candidates are merged; the last non-empty finishReason wins. Usage is only set
// when the response reports a positive total token count.
func parseGeminiResponse(resp *geminiGenerateContentResponse) *LLMResponse {
	contentParts := make([]string, 0)
	reasoningParts := make([]string, 0)
	toolCalls := make([]ToolCall, 0)
	finishReason := ""

	for _, candidate := range resp.Candidates {
		for _, part := range candidate.Content.Parts {
			if part.Text != "" {
				if part.Thought {
					reasoningParts = append(reasoningParts, part.Text)
				} else {
					contentParts = append(contentParts, part.Text)
				}
			}
			if part.FunctionCall != nil {
				toolCall := buildGeminiToolCall(part)
				// When the server sent no ID, buildGeminiToolCall falls back to a
				// timestamp-based one. Two calls to the same function built within one
				// clock tick would share it, so append the call's ordinal to keep IDs
				// unique within this response.
				if strings.TrimSpace(part.FunctionCall.ID) == "" {
					toolCall.ID = fmt.Sprintf("%s_%d", toolCall.ID, len(toolCalls)+1)
				}
				toolCalls = append(toolCalls, toolCall)
			}
		}
		if candidate.FinishReason != "" {
			finishReason = candidate.FinishReason
		}
	}

	var usage *UsageInfo
	if resp.UsageMetadata.TotalTokenCount > 0 {
		usage = &UsageInfo{
			PromptTokens:     resp.UsageMetadata.PromptTokenCount,
			CompletionTokens: resp.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      resp.UsageMetadata.TotalTokenCount,
		}
	}

	return &LLMResponse{
		Content:          strings.Join(contentParts, ""),
		ReasoningContent: strings.Join(reasoningParts, ""),
		ToolCalls:        toolCalls,
		FinishReason:     normalizeGeminiFinishReason(finishReason, len(toolCalls)),
		Usage:            usage,
	}
}
// parseGeminiStreamResponse consumes a server-sent-event stream of generateContent
// chunks and folds it into a single LLMResponse.
//
// Only "data" lines are interpreted; a "[DONE]" payload ends the stream and chunks
// that are not valid JSON are skipped (best effort). Visible text is accumulated
// and, when onChunk is non-nil, reported after every non-thought text part; thought
// text goes to ReasoningContent. Tool calls are keyed by ID, so a later chunk
// carrying the same ID replaces the earlier call while keeping its position.
// The last reported finish reason and usage metadata win.
//
// It returns ctx.Err() if the context is done between lines, or a wrapped scanner
// error if reading fails (including a line above the 10 MiB buffer limit).
func parseGeminiStreamResponse(
	ctx context.Context,
	reader io.Reader,
	onChunk func(accumulated string),
) (*LLMResponse, error) {
	var contentBuilder strings.Builder
	var reasoningBuilder strings.Builder
	var finishReason string
	var usage *UsageInfo
	toolCallsByID := make(map[string]ToolCall)
	toolCallOrder := make([]string, 0)
	fallbackIndex := 0

	scanner := bufio.NewScanner(reader)
	scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024)
	for scanner.Scan() {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		line := scanner.Text()
		// The SSE format allows "data:" with or without one following space, so
		// accept both instead of requiring "data: ".
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if data == "[DONE]" {
			break
		}

		var chunk geminiGenerateContentResponse
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			// Best effort: ignore chunks that are not valid JSON.
			continue
		}

		for _, candidate := range chunk.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.Text != "" {
					if part.Thought {
						reasoningBuilder.WriteString(part.Text)
					} else {
						contentBuilder.WriteString(part.Text)
						if onChunk != nil {
							onChunk(contentBuilder.String())
						}
					}
				}
				if part.FunctionCall != nil {
					tc := buildGeminiToolCall(part)
					// buildGeminiToolCall always fills in an ID, so a missing server ID
					// has to be detected on the raw part. Its timestamp-based fallback
					// can repeat within one clock tick, which would make two distinct
					// calls overwrite each other in toolCallsByID; the per-stream
					// counter keeps such IDs unique.
					if strings.TrimSpace(part.FunctionCall.ID) == "" {
						fallbackIndex++
						tc.ID = fmt.Sprintf("%s_%d", tc.ID, fallbackIndex)
					}
					key := tc.ID
					if _, exists := toolCallsByID[key]; !exists {
						toolCallOrder = append(toolCallOrder, key)
					}
					toolCallsByID[key] = tc
				}
			}
			if candidate.FinishReason != "" {
				finishReason = candidate.FinishReason
			}
		}

		if chunk.UsageMetadata.TotalTokenCount > 0 {
			usage = &UsageInfo{
				PromptTokens:     chunk.UsageMetadata.PromptTokenCount,
				CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
				TotalTokens:      chunk.UsageMetadata.TotalTokenCount,
			}
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("streaming read error: %w", err)
	}

	// Emit tool calls in the order their IDs were first seen.
	toolCalls := make([]ToolCall, 0, len(toolCallOrder))
	for _, key := range toolCallOrder {
		toolCalls = append(toolCalls, toolCallsByID[key])
	}

	return &LLMResponse{
		Content:          contentBuilder.String(),
		ReasoningContent: reasoningBuilder.String(),
		ToolCalls:        toolCalls,
		FinishReason:     normalizeGeminiFinishReason(finishReason, len(toolCalls)),
		Usage:            usage,
	}, nil
}
// normalizeGeminiFinishReason converts a Gemini finishReason into the provider-neutral
// form. Any response with tool calls reports "tool_calls"; otherwise MAX_TOKENS maps
// to "length", STOP or an empty reason to "stop", and anything else is returned
// trimmed and lower-cased.
func normalizeGeminiFinishReason(reason string, toolCalls int) string {
	if toolCalls > 0 {
		return "tool_calls"
	}

	trimmed := strings.TrimSpace(reason)
	upper := strings.ToUpper(trimmed)
	if upper == "" || upper == "STOP" {
		return "stop"
	}
	if upper == "MAX_TOKENS" {
		return "length"
	}
	return strings.ToLower(trimmed)
}
// buildGeminiToolCall converts a functionCall part into a ToolCall. It returns the
// zero ToolCall when the part carries no function call.
//
// Arguments are exposed both as a map and as a JSON string, and a thought signature
// (from either JSON spelling) is copied to the tool call, its Function, and the
// Google extra content. When the server supplied no ID, one is derived from the
// function name and the current time in nanoseconds.
// NOTE(review): that fallback is only unique if the clock advances between calls.
func buildGeminiToolCall(part geminiPart) ToolCall {
	call := part.FunctionCall
	if call == nil {
		return ToolCall{}
	}

	arguments := call.Args
	if arguments == nil {
		arguments = make(map[string]any)
	}
	// Marshal error deliberately ignored; presumably the args come from decoded JSON
	// and re-encode cleanly. On failure the string form is left empty.
	encodedArgs, _ := json.Marshal(arguments)

	signature := extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake)

	id := call.ID
	if strings.TrimSpace(id) == "" {
		id = fmt.Sprintf("call_%s_%d", call.Name, time.Now().UnixNano())
	}

	result := ToolCall{
		ID:               id,
		Name:             call.Name,
		Arguments:        arguments,
		ThoughtSignature: signature,
		Function: &FunctionCall{
			Name:             call.Name,
			Arguments:        string(encodedArgs),
			ThoughtSignature: signature,
		},
	}
	if signature != "" {
		result.ExtraContent = &ExtraContent{
			Google: &GoogleExtra{ThoughtSignature: signature},
		}
	}
	return result
}
// buildInlineMediaParts converts base64 data URLs into inlineData parts. Entries
// that parseBase64DataURL rejects are skipped. The result is never nil.
func buildInlineMediaParts(media []string) []geminiPart {
	inline := make([]geminiPart, 0, len(media))
	for i := range media {
		mime, payload, valid := parseBase64DataURL(media[i])
		if valid {
			inline = append(inline, geminiPart{
				InlineData: &geminiInlineData{MIMEType: mime, Data: payload},
			})
		}
	}
	return inline
}
// buildGeminiFunctionResponse wraps a tool result as a functionResponse. The textual
// result is stored under the "result" key of the response object; media that parses
// as base64 data URLs is attached as additional parts.
func buildGeminiFunctionResponse(
	toolName string,
	toolCallID string,
	result string,
	media []string,
) *geminiFunctionResponse {
	out := new(geminiFunctionResponse)
	out.ID = toolCallID
	out.Name = toolName
	out.Response = map[string]any{"result": result}

	attachments := buildFunctionResponseMediaParts(media)
	if len(attachments) > 0 {
		out.Parts = attachments
	}
	return out
}
// buildFunctionResponseMediaParts converts base64 data URLs into function-response
// parts. Each part gets a display name built from its 1-based position in media, so
// skipped (unparseable) entries leave gaps in the numbering.
func buildFunctionResponseMediaParts(media []string) []geminiFunctionResponsePart {
	attachments := make([]geminiFunctionResponsePart, 0, len(media))
	for position := 0; position < len(media); position++ {
		mime, payload, valid := parseBase64DataURL(media[position])
		if !valid {
			continue
		}
		inline := &geminiInlineData{
			MIMEType:    mime,
			Data:        payload,
			DisplayName: defaultFunctionResponseDisplayName(mime, position+1),
		}
		attachments = append(attachments, geminiFunctionResponsePart{InlineData: inline})
	}
	return attachments
}
// defaultFunctionResponseDisplayName returns "attachment-<index>.<ext>", choosing the
// extension from the MIME type (case-insensitive, whitespace ignored) and using
// "bin" for anything it does not recognize.
func defaultFunctionResponseDisplayName(mimeType string, index int) string {
	extensions := map[string]string{
		"image/png":       "png",
		"image/jpeg":      "jpg",
		"image/webp":      "webp",
		"application/pdf": "pdf",
		"text/plain":      "txt",
	}

	ext, known := extensions[strings.ToLower(strings.TrimSpace(mimeType))]
	if !known {
		ext = "bin"
	}
	return fmt.Sprintf("attachment-%d.%s", index, ext)
}
// parseBase64DataURL splits a "data:<mime>[;param...];base64,<payload>" URL into its
// MIME type and base64 payload (both whitespace-trimmed).
//
// ok is false when the string is not a data URL, has no comma, has an empty MIME
// type or payload, or does not carry a "base64" parameter. The parameter must be an
// exact token (case-insensitive): a substring match would misclassify plain-text
// data URLs such as "data:text/plain;name=base64.txt,hello" as base64.
func parseBase64DataURL(mediaURL string) (mimeType string, data string, ok bool) {
	if !strings.HasPrefix(mediaURL, "data:") {
		return "", "", false
	}

	header, payload, found := strings.Cut(strings.TrimPrefix(mediaURL, "data:"), ",")
	if !found {
		return "", "", false
	}

	mediaType, params, _ := strings.Cut(header, ";")
	mediaType = strings.TrimSpace(mediaType)
	payload = strings.TrimSpace(payload)
	if mediaType == "" || payload == "" {
		return "", "", false
	}

	for _, param := range strings.Split(params, ";") {
		if strings.EqualFold(strings.TrimSpace(param), "base64") {
			return mediaType, payload, true
		}
	}
	return "", "", false
}
// cloneAnyMap returns a shallow copy of in, or nil when in is nil or empty. Nested
// values are shared with the original.
func cloneAnyMap(in map[string]any) map[string]any {
	size := len(in)
	if size == 0 {
		return nil
	}

	dup := make(map[string]any, size)
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
// cloneStringMap returns a copy of in, or nil when in is nil or empty.
func cloneStringMap(in map[string]string) map[string]string {
	size := len(in)
	if size == 0 {
		return nil
	}

	dup := make(map[string]string, size)
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
// geminiGenerateContentResponse is the subset of a generateContent response (or of
// one streamed chunk) that this provider decodes.
type geminiGenerateContentResponse struct {
// Candidates holds the generated alternatives; the parsers merge the parts of all of them.
Candidates []struct {
Content struct {
Role string `json:"role"`
Parts []geminiPart `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
// UsageMetadata carries token counts; the parsers ignore it unless TotalTokenCount is positive.
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
// geminiContent is one conversation turn in a request. The request builder sets
// Role to "user" or "model" and leaves it empty for the system instruction.
type geminiContent struct {
Role string `json:"role,omitempty"`
Parts []geminiPart `json:"parts"`
}
// geminiPart is a single piece of a content turn, used for both requests and
// responses. All fields are optional in the JSON encoding.
type geminiPart struct {
Text string `json:"text,omitempty"`
// Thought marks Text as reasoning rather than visible output.
Thought bool `json:"thought,omitempty"`
// The thought signature is modelled under both its camelCase and snake_case JSON keys.
ThoughtSignature string `json:"thoughtSignature,omitempty"`
ThoughtSignatureSnake string `json:"thought_signature,omitempty"`
InlineData *geminiInlineData `json:"inlineData,omitempty"`
FunctionCall *geminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *geminiFunctionResponse `json:"functionResponse,omitempty"`
}
// geminiInlineData is embedded media: a MIME type plus the payload string taken
// from a base64 data URL. DisplayName is only set for function-response attachments.
type geminiInlineData struct {
MIMEType string `json:"mimeType"`
Data string `json:"data"`
DisplayName string `json:"displayName,omitempty"`
}
// geminiFunctionCall is a function invocation, either received from the model or
// replayed from conversation history. ID and Args may be absent.
type geminiFunctionCall struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Args map[string]any `json:"args,omitempty"`
}
// geminiFunctionResponse is a tool result sent back to the model. Response is the
// result object and Parts carries optional media attachments.
type geminiFunctionResponse struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Response map[string]any `json:"response"`
Parts []geminiFunctionResponsePart `json:"parts,omitempty"`
}
// geminiFunctionResponsePart is one media attachment of a function response.
type geminiFunctionResponsePart struct {
InlineData *geminiInlineData `json:"inlineData,omitempty"`
}
// geminiTool groups the function declarations offered to the model in a request.
type geminiTool struct {
FunctionDeclarations []geminiFunctionDeclaration `json:"functionDeclarations"`
}
// geminiFunctionDeclaration describes one callable function. Parameters holds the
// schema value returned by sanitizeSchemaForGemini.
type geminiFunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters any `json:"parameters,omitempty"`
}
+440
View File
@@ -0,0 +1,440 @@
package providers
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestGeminiProvider_ChatSeparatesThoughtAndToolCall drives Chat against a stub
// server and checks that thought text, visible text, tool calls (including the
// thought signature) and usage are split correctly, and that the request carries the
// expected thinkingConfig for a Gemini 3 model.
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
	var capturedBody map[string]any

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine. t.Fatalf/FailNow must only be called
		// from the test goroutine, so failures here are reported with t.Errorf and an
		// early return; the empty reply then fails the Chat call in the test body.
		if r.Method != http.MethodPost {
			t.Errorf("method = %s, want POST", r.Method)
			return
		}
		if !strings.Contains(r.URL.Path, ":generateContent") {
			t.Errorf("path = %s, expected generateContent endpoint", r.URL.Path)
			return
		}
		if got := r.Header.Get("x-goog-api-key"); got != "test-key" {
			t.Errorf("x-goog-api-key = %q, want %q", got, "test-key")
			return
		}
		if err := json.NewDecoder(r.Body).Decode(&capturedBody); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"candidates": []any{
				map[string]any{
					"content": map[string]any{
						"role": "model",
						"parts": []any{
							map[string]any{"text": "hidden", "thought": true},
							map[string]any{"text": "visible"},
							map[string]any{
								"functionCall": map[string]any{
									"id":   "call_1",
									"name": "search",
									"args": map[string]any{"q": "hi"},
								},
								"thoughtSignature": "sig-1",
							},
						},
					},
					"finishReason": "STOP",
				},
			},
			"usageMetadata": map[string]any{
				"promptTokenCount":     2,
				"candidatesTokenCount": 3,
				"totalTokenCount":      5,
			},
		})
	}))
	defer server.Close()

	provider := NewGeminiProvider("test-key", server.URL, "", "picoclaw-test", 0, nil, nil)
	resp, err := provider.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-3-flash-preview",
		map[string]any{"thinking_level": "high"},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if resp.Content != "visible" {
		t.Fatalf("Content = %q, want %q", resp.Content, "visible")
	}
	if resp.ReasoningContent != "hidden" {
		t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden")
	}
	if resp.FinishReason != "tool_calls" {
		t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
	}
	if resp.Usage == nil || resp.Usage.TotalTokens != 5 {
		t.Fatalf("Usage = %#v, expected total tokens = 5", resp.Usage)
	}
	if len(resp.ToolCalls) != 1 {
		t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
	}
	if resp.ToolCalls[0].ID != "call_1" {
		t.Fatalf("ToolCall ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
	}
	if resp.ToolCalls[0].Name != "search" {
		t.Fatalf("ToolCall Name = %q, want %q", resp.ToolCalls[0].Name, "search")
	}
	if resp.ToolCalls[0].ThoughtSignature != "sig-1" {
		t.Fatalf("ToolCall ThoughtSignature = %q, want %q", resp.ToolCalls[0].ThoughtSignature, "sig-1")
	}
	if resp.ToolCalls[0].Function == nil || !strings.Contains(resp.ToolCalls[0].Function.Arguments, `"q":"hi"`) {
		t.Fatalf("ToolCall Function arguments = %#v, want q=hi", resp.ToolCalls[0].Function)
	}

	generationConfig, ok := capturedBody["generationConfig"].(map[string]any)
	if !ok {
		t.Fatalf("request missing generationConfig: %#v", capturedBody)
	}
	thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
	if !ok {
		t.Fatalf("request missing thinkingConfig: %#v", generationConfig)
	}
	if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts {
		t.Fatalf("thinkingConfig.includeThoughts = %#v, want true", thinkingConfig["includeThoughts"])
	}
	if got := thinkingConfig["thinkingLevel"]; got != "high" {
		t.Fatalf("thinkingConfig.thinkingLevel = %#v, want %q", got, "high")
	}
}
// TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls streams two SSE chunks
// plus a [DONE] marker from a stub server and checks the accumulated text, reasoning,
// tool call, finish reason, usage, and the onChunk updates.
func TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where t.Fatal/t.Fatalf must not be
		// called; report with t.Error/t.Errorf and return instead.
		if !strings.Contains(r.URL.Path, ":streamGenerateContent") {
			t.Errorf("path = %s, expected streamGenerateContent endpoint", r.URL.Path)
			return
		}
		if got := r.URL.Query().Get("alt"); got != "sse" {
			t.Errorf("alt query = %q, want %q", got, "sse")
			return
		}

		w.Header().Set("Content-Type", "text/event-stream")
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Error("response writer is not flushable")
			return
		}

		chunks := []map[string]any{
			{
				"candidates": []any{map[string]any{
					"content": map[string]any{
						"parts": []any{
							map[string]any{"text": "think ", "thought": true},
							map[string]any{"text": "Hello "},
						},
					},
				}},
			},
			{
				"candidates": []any{map[string]any{
					"content": map[string]any{
						"parts": []any{
							map[string]any{"text": "World"},
							map[string]any{
								"functionCall": map[string]any{
									"id":   "call_stream",
									"name": "search",
									"args": map[string]any{"q": "stream"},
								},
							},
						},
					},
					"finishReason": "STOP",
				}},
				"usageMetadata": map[string]any{
					"promptTokenCount":     1,
					"candidatesTokenCount": 2,
					"totalTokenCount":      3,
				},
			},
		}

		for _, chunk := range chunks {
			raw, err := json.Marshal(chunk)
			if err != nil {
				t.Errorf("marshal chunk: %v", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil {
				t.Errorf("write chunk: %v", err)
				return
			}
			flusher.Flush()
		}
		_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer server.Close()

	provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
	updates := make([]string, 0)
	resp, err := provider.ChatStream(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
		func(accumulated string) {
			updates = append(updates, accumulated)
		},
	)
	if err != nil {
		t.Fatalf("ChatStream() error = %v", err)
	}

	if resp.Content != "Hello World" {
		t.Fatalf("Content = %q, want %q", resp.Content, "Hello World")
	}
	if resp.ReasoningContent != "think " {
		t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "think ")
	}
	if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].ID != "call_stream" {
		t.Fatalf("ToolCalls = %#v, want single call_stream", resp.ToolCalls)
	}
	if resp.FinishReason != "tool_calls" {
		t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
	}
	if resp.Usage == nil || resp.Usage.TotalTokens != 3 {
		t.Fatalf("Usage = %#v, expected total tokens = 3", resp.Usage)
	}
	if len(updates) < 2 || updates[len(updates)-1] != "Hello World" {
		t.Fatalf("stream updates = %#v, expected final accumulated text", updates)
	}
}
// TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig checks that user
// media data URLs become inlineData parts and that max_tokens, temperature and
// thinking_level end up in generationConfig for a Gemini 3 model.
func TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)

	conversation := []Message{{
		Role:    "user",
		Content: "analyze attachments",
		Media: []string{
			"data:application/pdf;base64,UEZERGF0YQ==",
			"data:image/png;base64,aW1hZ2VEYXRh",
		},
	}}
	opts := map[string]any{
		"thinking_level": "low",
		"max_tokens":     128,
		"temperature":    0.2,
	}
	body := p.buildRequestBody(conversation, nil, "gemini-3-flash-preview", opts)

	contents, ok := body["contents"].([]geminiContent)
	if !ok || len(contents) != 1 {
		t.Fatalf("contents = %#v, want one gemini content", body["contents"])
	}

	parts := contents[0].Parts
	seenMIME := make(map[string]bool)
	for i := range parts {
		if inline := parts[i].InlineData; inline != nil {
			seenMIME[inline.MIMEType] = true
		}
	}
	if !seenMIME["application/pdf"] {
		t.Fatalf("inline media missing application/pdf: %#v", parts)
	}
	if !seenMIME["image/png"] {
		t.Fatalf("inline media missing image/png: %#v", parts)
	}

	genCfg, ok := body["generationConfig"].(map[string]any)
	if !ok {
		t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
	}
	if got := genCfg["maxOutputTokens"]; got != 128 {
		t.Fatalf("maxOutputTokens = %#v, want 128", got)
	}
	if got := genCfg["temperature"]; got != 0.2 {
		t.Fatalf("temperature = %#v, want 0.2", got)
	}

	thinking, ok := genCfg["thinkingConfig"].(map[string]any)
	if !ok {
		t.Fatalf("thinkingConfig = %#v, want map", genCfg["thinkingConfig"])
	}
	if include, isBool := thinking["includeThoughts"].(bool); !isBool || !include {
		t.Fatalf("includeThoughts = %#v, want true", thinking["includeThoughts"])
	}
	if got := thinking["thinkingLevel"]; got != "low" {
		t.Fatalf("thinkingLevel = %#v, want %q", got, "low")
	}
}
// TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25 checks that a
// Gemini 2.5 model receives a numeric thinkingBudget and no thinkingLevel.
func TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)

	opts := map[string]any{"thinking_level": "medium"}
	body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-flash", opts)

	genCfg, ok := body["generationConfig"].(map[string]any)
	if !ok {
		t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
	}
	thinking, ok := genCfg["thinkingConfig"].(map[string]any)
	if !ok {
		t.Fatalf("thinkingConfig = %#v, want map", genCfg["thinkingConfig"])
	}

	if got := thinking["thinkingBudget"]; got != 4096 {
		t.Fatalf("thinkingBudget = %#v, want 4096", got)
	}
	if _, present := thinking["thinkingLevel"]; present {
		t.Fatalf("thinkingLevel should not be set for Gemini 2.5: %#v", thinking)
	}
}
// TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20 checks that a
// Gemini 2.0 model gets no generationConfig when thinking_level is the only option.
func TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)

	opts := map[string]any{"thinking_level": "high"}
	body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.0-flash-exp", opts)

	if _, present := body["generationConfig"]; present {
		t.Fatalf("generationConfig should be omitted for Gemini 2.0 when only thinking_level is set: %#v", body)
	}
}
// TestGeminiProvider_BuildRequestBody_PreservesToolResponseMedia checks that a tool
// message is converted into a functionResponse named after the preceding assistant
// tool call, with its text under "result" and both media entries attached as parts.
func TestGeminiProvider_BuildRequestBody_PreservesToolResponseMedia(t *testing.T) {
	p := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)

	assistantTurn := Message{
		Role: "assistant",
		ToolCalls: []ToolCall{{
			ID:        "call_1",
			Name:      "load_image",
			Arguments: map[string]any{"path": "demo.png"},
		}},
	}
	toolTurn := Message{
		Role:       "tool",
		ToolCallID: "call_1",
		Content:    "tool result",
		Media: []string{
			"data:image/png;base64,aW1hZ2VEYXRh",
			"data:application/pdf;base64,UEZERGF0YQ==",
		},
	}
	body := p.buildRequestBody([]Message{assistantTurn, toolTurn}, nil, "gemini-3-flash-preview", nil)

	contents, ok := body["contents"].([]geminiContent)
	if !ok || len(contents) != 2 {
		t.Fatalf("contents = %#v, want two content entries", body["contents"])
	}

	parts := contents[1].Parts
	if len(parts) != 1 || parts[0].FunctionResponse == nil {
		t.Fatalf("tool response part = %#v, want functionResponse", parts)
	}

	fnResp := parts[0].FunctionResponse
	if fnResp.Name != "load_image" {
		t.Fatalf("functionResponse.Name = %q, want %q", fnResp.Name, "load_image")
	}
	if fnResp.Response["result"] != "tool result" {
		t.Fatalf("functionResponse.Response = %#v, want result=tool result", fnResp.Response)
	}
	if len(fnResp.Parts) != 2 {
		t.Fatalf("functionResponse.Parts len = %d, want 2", len(fnResp.Parts))
	}
}
// TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey verifies that Chat
// succeeds with an empty API key when a custom Authorization header is
// configured, that the header reaches the server unchanged, and that no
// x-goog-api-key header is sent alongside it.
func TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutine, not the test goroutine,
		// so failures are reported with t.Errorf; t.Fatalf (FailNow) is only
		// valid when called from the goroutine running the test.
		if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
			t.Errorf("Authorization = %q, want %q", got, "Bearer test-token")
		}
		if got := r.Header.Get("x-goog-api-key"); got != "" {
			t.Errorf("x-goog-api-key = %q, want empty", got)
		}

		w.Header().Set("Content-Type", "application/json")
		err := json.NewEncoder(w).Encode(map[string]any{
			"candidates": []any{
				map[string]any{
					"content": map[string]any{
						"parts": []any{map[string]any{"text": "ok"}},
					},
					"finishReason": "STOP",
				},
			},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	provider := NewGeminiProvider(
		"",
		server.URL,
		"",
		"",
		0,
		nil,
		map[string]string{"Authorization": "Bearer test-token"},
	)

	resp, err := provider.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if resp.Content != "ok" {
		t.Fatalf("Content = %q, want %q", resp.Content, "ok")
	}
}
// TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase verifies that Chat
// succeeds against a custom API base when no API key is configured, and that
// the request carries no x-goog-api-key header in that case.
func TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutine, not the test goroutine,
		// so failures are reported with t.Errorf; t.Fatalf (FailNow) is only
		// valid when called from the goroutine running the test.
		if got := r.Header.Get("x-goog-api-key"); got != "" {
			t.Errorf("x-goog-api-key = %q, want empty", got)
		}

		w.Header().Set("Content-Type", "application/json")
		err := json.NewEncoder(w).Encode(map[string]any{
			"candidates": []any{
				map[string]any{
					"content":      map[string]any{"parts": []any{map[string]any{"text": "ok"}}},
					"finishReason": "STOP",
				},
			},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()

	provider := NewGeminiProvider("", server.URL, "", "", 0, nil, nil)

	resp, err := provider.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hello"}},
		nil,
		"gemini-2.5-flash",
		nil,
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if resp.Content != "ok" {
		t.Fatalf("Content = %q, want %q", resp.Content, "ok")
	}
}
+2 -3
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log"
"maps"
"net/http"
"net/url"
"strings"
@@ -181,9 +182,7 @@ func (p *Provider) buildRequestBody(
// Merge extra body fields configured per-provider/model.
// These are injected last so they take precedence over defaults.
for k, v := range p.extraBody {
requestBody[k] = v
}
maps.Copy(requestBody, p.extraBody)
return requestBody
}
+13
View File
@@ -281,6 +281,12 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
}
case "assistant":
// Reasoning-only assistant messages are transient display artifacts and
// should not be restored from session history.
if assistantMessageTransientThought(msg) {
continue
}
toolSummaryMessages := visibleAssistantToolSummaryMessages(msg.ToolCalls, toolFeedbackMaxArgsLength)
if len(toolSummaryMessages) > 0 {
transcript = append(transcript, toolSummaryMessages...)
@@ -309,6 +315,13 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
return transcript
}
// assistantMessageTransientThought reports whether msg consists of reasoning
// text only: its content is blank, it has no tool calls and no media, and its
// reasoning content is non-blank.
func assistantMessageTransientThought(msg providers.Message) bool {
	// Any tool call or media attachment makes the message more than a thought.
	if len(msg.ToolCalls) != 0 || len(msg.Media) != 0 {
		return false
	}
	if strings.TrimSpace(msg.Content) != "" {
		return false
	}
	return strings.TrimSpace(msg.ReasoningContent) != ""
}
// assistantMessageInternalOnly reports whether msg's content, with surrounding
// whitespace removed, is exactly handledToolResponseSummaryText.
func assistantMessageInternalOnly(msg providers.Message) bool {
	trimmed := strings.TrimSpace(msg.Content)
	return trimmed == handledToolResponseSummaryText
}
+53
View File
@@ -218,6 +218,59 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) {
}
}
// TestHandleGetSession_OmitsTransientThoughtMessages stores a session holding a
// user message, an assistant message with only ReasoningContent, and a final
// assistant answer, then checks that GET /api/sessions/<id> returns just the
// user message and the final answer.
func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) {
	configPath, cleanup := setupOAuthTestEnv(t)
	defer cleanup()

	dir := sessionsTestDir(t, configPath)
	store, err := memory.NewJSONLStore(dir)
	if err != nil {
		t.Fatalf("NewJSONLStore() error = %v", err)
	}

	sessionKey := picoSessionPrefix + "detail-transient-thought"
	for _, msg := range []providers.Message{
		{Role: "user", Content: "hello"},
		{Role: "assistant", ReasoningContent: "internal chain of thought"},
		{Role: "assistant", Content: "final visible answer"},
	} {
		// Pass the test's context rather than a literal nil; the context
		// package documents that a nil Context must never be passed.
		// NOTE(review): assumes the first parameter is a context.Context — confirm.
		if err := store.AddFullMessage(t.Context(), sessionKey, msg); err != nil {
			t.Fatalf("AddFullMessage() error = %v", err)
		}
	}

	h := NewHandler(configPath)
	mux := http.NewServeMux()
	h.RegisterRoutes(mux)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-transient-thought", nil)
	mux.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
	}

	// Only the role and content of each returned message matter here.
	var resp struct {
		Messages []struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"messages"`
	}
	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
		t.Fatalf("Unmarshal() error = %v", err)
	}

	if len(resp.Messages) != 2 {
		t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
	}
	if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "hello" {
		t.Fatalf("first message = %#v, want user/hello", resp.Messages[0])
	}
	if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "final visible answer" {
		t.Fatalf("second message = %#v, want assistant/final visible answer", resp.Messages[1])
	}
}
func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()