diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index f13dc646c..ab68b326a 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -114,7 +114,7 @@ func ResolveAPIBase(cfg *config.ModelConfig) string { // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. -// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini), +// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq), // Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims. // See the switch on protocol in this function for the authoritative list. // Returns the provider, the model ID (without protocol prefix), and any error. @@ -218,7 +218,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err } return provider, modelID, nil - case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "gemini", "nvidia", "venice", + case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", "qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita", @@ -242,6 +242,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.CustomHeaders, ), modelID, nil + case "gemini": + if cfg.APIKey() == "" && cfg.APIBase == "" { + return nil, "", fmt.Errorf("api_key or api_base is required for gemini protocol (model: %s)", cfg.Model) + } + apiBase := cfg.APIBase + if apiBase == "" { + apiBase = getDefaultAPIBase(protocol) + } + return NewGeminiProvider( + cfg.APIKey(), + apiBase, + cfg.Proxy, + userAgent, + cfg.RequestTimeout, + cfg.ExtraBody, + cfg.CustomHeaders, + ), modelID, nil + case "minimax": // Minimax requires reasoning_split: true in the request body if cfg.APIKey() == "" && cfg.APIBase == "" { diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index c362463ae..20cdd8a30 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -434,6 +434,62 @@ func TestCreateProviderFromConfig_Antigravity(t *testing.T) { } } +func TestCreateProviderFromConfig_Gemini(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-gemini", + Model: "gemini/gemini-2.5-flash", + } + cfg.SetAPIKey("test-key") + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gemini-2.5-flash" { + t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash") + } + if _, ok := provider.(*GeminiProvider); !ok { + t.Fatalf("expected *GeminiProvider, got %T", provider) + } +} + +func TestCreateProviderFromConfig_GeminiMissingAPIKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-gemini-no-key", + Model: "gemini/gemini-2.5-flash", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing gemini API key") + } +} + +func TestCreateProviderFromConfig_GeminiCustomAPIBaseWithoutKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-gemini-custom-base", + Model: "gemini/gemini-2.5-flash", + APIBase: "https://proxy.example.com/v1beta", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gemini-2.5-flash" { + t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash") + } + if _, ok := provider.(*GeminiProvider); !ok { + t.Fatalf("expected *GeminiProvider, got %T", provider) + } +} + func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) { cfg := &config.ModelConfig{ ModelName: "test-claude-cli", diff --git a/pkg/providers/gemini_provider.go b/pkg/providers/gemini_provider.go new file mode 100644 index 000000000..b3042fcd7 --- /dev/null +++ b/pkg/providers/gemini_provider.go @@ -0,0 +1,758 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/common" +) + +const ( + geminiDefaultAPIBase = "https://generativelanguage.googleapis.com/v1beta" + geminiDefaultModel = "gemini-2.0-flash" +) + +type GeminiProvider struct { + apiKey string + apiBase string + httpClient *http.Client + extraBody map[string]any + customHeaders map[string]string + userAgent string +} + +func NewGeminiProvider( + apiKey string, + apiBase string, + proxy string, + userAgent string, + requestTimeoutSeconds int, + extraBody map[string]any, + customHeaders map[string]string, +) *GeminiProvider { + if strings.TrimSpace(apiBase) == "" { + apiBase = geminiDefaultAPIBase + } + client := common.NewHTTPClient(proxy) + if requestTimeoutSeconds > 0 { + client.Timeout = time.Duration(requestTimeoutSeconds) * time.Second + } + + return &GeminiProvider{ + apiKey: strings.TrimSpace(apiKey), + apiBase: strings.TrimRight(strings.TrimSpace(apiBase), "/"), + httpClient: client, + extraBody: cloneAnyMap(extraBody), + customHeaders: cloneStringMap(customHeaders), + userAgent: strings.TrimSpace(userAgent), + } +} + +func (p *GeminiProvider) GetDefaultModel() string { + return geminiDefaultModel +} + +func (p *GeminiProvider) SupportsThinking() bool { + return true +} + +func (p *GeminiProvider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + model = normalizeGeminiModel(model) + requestBody := p.buildRequestBody(messages, tools, model, options) + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("%s/models/%s:generateContent", p.apiBase, model) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + p.applyHeaders(req) + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, common.HandleErrorResponse(resp, p.apiBase) + } + + var apiResp geminiGenerateContentResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return parseGeminiResponse(&apiResp), nil +} + +func (p *GeminiProvider) ChatStream( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + onChunk func(accumulated string), +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + model = normalizeGeminiModel(model) + requestBody := p.buildRequestBody(messages, tools, model, options) + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("%s/models/%s:streamGenerateContent?alt=sse", p.apiBase, model) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + p.applyHeaders(req) + req.Header.Set("Accept", "text/event-stream") + + // Streaming should not use a whole-request timeout; context cancellation is the guard. + streamClient := &http.Client{Transport: p.httpClient.Transport} + resp, err := streamClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, common.HandleErrorResponse(resp, p.apiBase) + } + + return parseGeminiStreamResponse(ctx, resp.Body, onChunk) +} + +func (p *GeminiProvider) applyHeaders(req *http.Request) { + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("x-goog-api-key", p.apiKey) + } + if p.userAgent != "" { + req.Header.Set("User-Agent", p.userAgent) + } + for k, v := range p.customHeaders { + if strings.TrimSpace(k) == "" { + continue + } + req.Header.Set(k, v) + } +} + +func (p *GeminiProvider) buildRequestBody( + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) map[string]any { + contents := make([]geminiContent, 0, len(messages)) + toolCallNames := make(map[string]string) + var systemInstruction *geminiContent + + for _, msg := range messages { + switch msg.Role { + case "system": + if strings.TrimSpace(msg.Content) != "" { + systemInstruction = &geminiContent{Parts: []geminiPart{{Text: msg.Content}}} + } + + case "user": + if msg.ToolCallID != "" { + toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + contents = append(contents, geminiContent{ + Role: "user", + Parts: []geminiPart{{ + FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media), + }}, + }) + continue + } + + parts := make([]geminiPart, 0, 1+len(msg.Media)) + if strings.TrimSpace(msg.Content) != "" { + parts = append(parts, geminiPart{Text: msg.Content}) + } + parts = append(parts, buildInlineMediaParts(msg.Media)...) + if len(parts) > 0 { + contents = append(contents, geminiContent{Role: "user", Parts: parts}) + } + + case "assistant": + content := geminiContent{Role: "model"} + if strings.TrimSpace(msg.Content) != "" { + content.Parts = append(content.Parts, geminiPart{Text: msg.Content}) + } + for _, tc := range msg.ToolCalls { + toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc) + if toolName == "" { + continue + } + if tc.ID != "" { + toolCallNames[tc.ID] = toolName + } + part := geminiPart{ + FunctionCall: &geminiFunctionCall{ + Name: toolName, + Args: toolArgs, + ID: tc.ID, + }, + } + if thoughtSignature != "" { + part.ThoughtSignature = thoughtSignature + part.ThoughtSignatureSnake = thoughtSignature + } + content.Parts = append(content.Parts, part) + } + if len(content.Parts) > 0 { + contents = append(contents, content) + } + + case "tool": + toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + contents = append(contents, geminiContent{ + Role: "user", + Parts: []geminiPart{{ + FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media), + }}, + }) + } + } + + body := map[string]any{ + "contents": contents, + } + if systemInstruction != nil { + body["systemInstruction"] = systemInstruction + } + + if len(tools) > 0 { + funcDecls := make([]geminiFunctionDeclaration, 0, len(tools)) + for _, t := range tools { + if t.Type != "function" { + continue + } + funcDecls = append(funcDecls, geminiFunctionDeclaration{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: sanitizeSchemaForGemini(t.Function.Parameters), + }) + } + if len(funcDecls) > 0 { + body["tools"] = []geminiTool{{FunctionDeclarations: funcDecls}} + } + } + + generationConfig := make(map[string]any) + if val, ok := options["max_tokens"]; ok { + if maxTokens, ok := val.(int); ok && maxTokens > 0 { + generationConfig["maxOutputTokens"] = maxTokens + } else if maxTokens, ok := val.(float64); ok && maxTokens > 0 { + generationConfig["maxOutputTokens"] = int(maxTokens) + } + } + if temp, ok := options["temperature"].(float64); ok { + generationConfig["temperature"] = temp + } + + if thinkingConfig := buildGeminiThinkingConfig(model, options); len(thinkingConfig) > 0 { + generationConfig["thinkingConfig"] = thinkingConfig + } + + if len(generationConfig) > 0 { + body["generationConfig"] = generationConfig + } + + for k, v := range p.extraBody { + body[k] = v + } + + return body +} + +func normalizeGeminiModel(model string) string { + model = strings.TrimSpace(model) + model = strings.TrimPrefix(model, "models/") + if strings.Contains(model, "/") { + _, modelID := ExtractProtocol(model) + if modelID != "" { + return modelID + } + } + if model == "" { + return geminiDefaultModel + } + return model +} + +func mapGeminiThinkingLevel(level string) string { + switch strings.ToLower(strings.TrimSpace(level)) { + case "minimal", "off": + return "minimal" + case "low": + return "low" + case "medium": + return "medium" + case "high", "xhigh", "adaptive": + return "high" + default: + return "" + } +} + +func buildGeminiThinkingConfig(model string, options map[string]any) map[string]any { + if !geminiModelSupportsThinkingConfig(model) { + return nil + } + + config := map[string]any{"includeThoughts": true} + rawLevel, _ := options["thinking_level"].(string) + rawLevel = strings.ToLower(strings.TrimSpace(rawLevel)) + + if isGemini25Model(model) { + if budget, ok := mapGeminiThinkingBudget(rawLevel, model); ok { + config["thinkingBudget"] = budget + } + return config + } + + if thinkingLevel := mapGeminiThinkingLevel(rawLevel); thinkingLevel != "" { + config["thinkingLevel"] = thinkingLevel + } + return config +} + +func geminiModelSupportsThinkingConfig(model string) bool { + lowerModel := strings.ToLower(strings.TrimSpace(model)) + return strings.Contains(lowerModel, "gemini-3") || isGemini25Model(lowerModel) +} + +func isGemini25Model(model string) bool { + lowerModel := strings.ToLower(strings.TrimSpace(model)) + return strings.Contains(lowerModel, "gemini-2.5") || strings.Contains(lowerModel, "gemini-25") +} + +func mapGeminiThinkingBudget(level string, model string) (int, bool) { + level = strings.ToLower(strings.TrimSpace(level)) + if level == "" { + return 0, false + } + + switch level { + case "adaptive": + return -1, true + case "minimal": + if strings.Contains(strings.ToLower(model), "pro") { + return 128, true + } + return 0, true + case "off": + if strings.Contains(strings.ToLower(model), "pro") { + // Gemini 2.5 Pro cannot disable thinking; use the lowest supported budget. + return 128, true + } + return 0, true + case "low": + return 1024, true + case "medium": + return 4096, true + case "high": + return 8192, true + case "xhigh": + return 16384, true + default: + return 0, false + } +} + +func parseGeminiResponse(resp *geminiGenerateContentResponse) *LLMResponse { + contentParts := make([]string, 0) + reasoningParts := make([]string, 0) + toolCalls := make([]ToolCall, 0) + finishReason := "" + + for _, candidate := range resp.Candidates { + for _, part := range candidate.Content.Parts { + if part.Text != "" { + if part.Thought { + reasoningParts = append(reasoningParts, part.Text) + } else { + contentParts = append(contentParts, part.Text) + } + } + if part.FunctionCall != nil { + toolCalls = append(toolCalls, buildGeminiToolCall(part)) + } + } + if candidate.FinishReason != "" { + finishReason = candidate.FinishReason + } + } + + var usage *UsageInfo + if resp.UsageMetadata.TotalTokenCount > 0 { + usage = &UsageInfo{ + PromptTokens: resp.UsageMetadata.PromptTokenCount, + CompletionTokens: resp.UsageMetadata.CandidatesTokenCount, + TotalTokens: resp.UsageMetadata.TotalTokenCount, + } + } + + return &LLMResponse{ + Content: strings.Join(contentParts, ""), + ReasoningContent: strings.Join(reasoningParts, ""), + ToolCalls: toolCalls, + FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)), + Usage: usage, + } +} + +func parseGeminiStreamResponse( + ctx context.Context, + reader io.Reader, + onChunk func(accumulated string), +) (*LLMResponse, error) { + var contentBuilder strings.Builder + var reasoningBuilder strings.Builder + var finishReason string + var usage *UsageInfo + + toolCallsByID := make(map[string]ToolCall) + toolCallOrder := make([]string, 0) + fallbackIndex := 0 + + scanner := bufio.NewScanner(reader) + scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024) + for scanner.Scan() { + if err := ctx.Err(); err != nil { + return nil, err + } + + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var chunk geminiGenerateContentResponse + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + for _, candidate := range chunk.Candidates { + for _, part := range candidate.Content.Parts { + if part.Text != "" { + if part.Thought { + reasoningBuilder.WriteString(part.Text) + } else { + contentBuilder.WriteString(part.Text) + if onChunk != nil { + onChunk(contentBuilder.String()) + } + } + } + if part.FunctionCall != nil { + tc := buildGeminiToolCall(part) + key := tc.ID + if strings.TrimSpace(key) == "" { + fallbackIndex++ + key = fmt.Sprintf("%s#%d", tc.Name, fallbackIndex) + tc.ID = key + } + if _, exists := toolCallsByID[key]; !exists { + toolCallOrder = append(toolCallOrder, key) + } + toolCallsByID[key] = tc + } + } + if candidate.FinishReason != "" { + finishReason = candidate.FinishReason + } + } + + if chunk.UsageMetadata.TotalTokenCount > 0 { + usage = &UsageInfo{ + PromptTokens: chunk.UsageMetadata.PromptTokenCount, + CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount, + TotalTokens: chunk.UsageMetadata.TotalTokenCount, + } + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("streaming read error: %w", err) + } + + toolCalls := make([]ToolCall, 0, len(toolCallOrder)) + for _, key := range toolCallOrder { + toolCalls = append(toolCalls, toolCallsByID[key]) + } + + return &LLMResponse{ + Content: contentBuilder.String(), + ReasoningContent: reasoningBuilder.String(), + ToolCalls: toolCalls, + FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)), + Usage: usage, + }, nil +} + +func normalizeGeminiFinishReason(reason string, toolCalls int) string { + if toolCalls > 0 { + return "tool_calls" + } + + switch strings.ToUpper(strings.TrimSpace(reason)) { + case "MAX_TOKENS": + return "length" + case "", "STOP": + return "stop" + default: + return strings.ToLower(strings.TrimSpace(reason)) + } +} + +func buildGeminiToolCall(part geminiPart) ToolCall { + if part.FunctionCall == nil { + return ToolCall{} + } + + args := part.FunctionCall.Args + if args == nil { + args = make(map[string]any) + } + argsJSON, _ := json.Marshal(args) + thoughtSignature := extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake) + + toolCall := ToolCall{ + ID: part.FunctionCall.ID, + Name: part.FunctionCall.Name, + Arguments: args, + ThoughtSignature: thoughtSignature, + Function: &FunctionCall{ + Name: part.FunctionCall.Name, + Arguments: string(argsJSON), + ThoughtSignature: thoughtSignature, + }, + } + + if thoughtSignature != "" { + toolCall.ExtraContent = &ExtraContent{ + Google: &GoogleExtra{ThoughtSignature: thoughtSignature}, + } + } + if strings.TrimSpace(toolCall.ID) == "" { + toolCall.ID = fmt.Sprintf("call_%s_%d", toolCall.Name, time.Now().UnixNano()) + } + + return toolCall +} + +func buildInlineMediaParts(media []string) []geminiPart { + parts := make([]geminiPart, 0, len(media)) + for _, mediaURL := range media { + mimeType, data, ok := parseBase64DataURL(mediaURL) + if !ok { + continue + } + parts = append(parts, geminiPart{ + InlineData: &geminiInlineData{ + MIMEType: mimeType, + Data: data, + }, + }) + } + return parts +} + +func buildGeminiFunctionResponse( + toolName string, + toolCallID string, + result string, + media []string, +) *geminiFunctionResponse { + response := &geminiFunctionResponse{ + ID: toolCallID, + Name: toolName, + Response: map[string]any{ + "result": result, + }, + } + + if parts := buildFunctionResponseMediaParts(media); len(parts) > 0 { + response.Parts = parts + } + + return response +} + +func buildFunctionResponseMediaParts(media []string) []geminiFunctionResponsePart { + parts := make([]geminiFunctionResponsePart, 0, len(media)) + for i, mediaURL := range media { + mimeType, data, ok := parseBase64DataURL(mediaURL) + if !ok { + continue + } + parts = append(parts, geminiFunctionResponsePart{ + InlineData: &geminiInlineData{ + MIMEType: mimeType, + Data: data, + DisplayName: defaultFunctionResponseDisplayName(mimeType, i+1), + }, + }) + } + return parts +} + +func defaultFunctionResponseDisplayName(mimeType string, index int) string { + suffix := "bin" + switch strings.ToLower(strings.TrimSpace(mimeType)) { + case "image/png": + suffix = "png" + case "image/jpeg": + suffix = "jpg" + case "image/webp": + suffix = "webp" + case "application/pdf": + suffix = "pdf" + case "text/plain": + suffix = "txt" + } + return fmt.Sprintf("attachment-%d.%s", index, suffix) +} + +func parseBase64DataURL(mediaURL string) (mimeType string, data string, ok bool) { + if !strings.HasPrefix(mediaURL, "data:") { + return "", "", false + } + + payload := strings.TrimPrefix(mediaURL, "data:") + header, data, found := strings.Cut(payload, ",") + if !found { + return "", "", false + } + mimeType, params, _ := strings.Cut(header, ";") + mimeType = strings.TrimSpace(mimeType) + data = strings.TrimSpace(data) + if mimeType == "" || data == "" { + return "", "", false + } + if !strings.Contains(strings.ToLower(params), "base64") { + return "", "", false + } + return mimeType, data, true +} + +func cloneAnyMap(in map[string]any) map[string]any { + if len(in) == 0 { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +type geminiGenerateContentResponse struct { + Candidates []struct { + Content struct { + Role string `json:"role"` + Parts []geminiPart `json:"parts"` + } `json:"content"` + FinishReason string `json:"finishReason"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` +} + +type geminiContent struct { + Role string `json:"role,omitempty"` + Parts []geminiPart `json:"parts"` +} + +type geminiPart struct { + Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + ThoughtSignatureSnake string `json:"thought_signature,omitempty"` + InlineData *geminiInlineData `json:"inlineData,omitempty"` + FunctionCall *geminiFunctionCall `json:"functionCall,omitempty"` + FunctionResponse *geminiFunctionResponse `json:"functionResponse,omitempty"` +} + +type geminiInlineData struct { + MIMEType string `json:"mimeType"` + Data string `json:"data"` + DisplayName string `json:"displayName,omitempty"` +} + +type geminiFunctionCall struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Args map[string]any `json:"args,omitempty"` +} + +type geminiFunctionResponse struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Response map[string]any `json:"response"` + Parts []geminiFunctionResponsePart `json:"parts,omitempty"` +} + +type geminiFunctionResponsePart struct { + InlineData *geminiInlineData `json:"inlineData,omitempty"` +} + +type geminiTool struct { + FunctionDeclarations []geminiFunctionDeclaration `json:"functionDeclarations"` +} + +type geminiFunctionDeclaration struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` +} diff --git a/pkg/providers/gemini_provider_test.go b/pkg/providers/gemini_provider_test.go new file mode 100644 index 000000000..c1bdf7c7f --- /dev/null +++ b/pkg/providers/gemini_provider_test.go @@ -0,0 +1,440 @@ +package providers + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) { + var capturedBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("method = %s, want POST", r.Method) + } + if !strings.Contains(r.URL.Path, ":generateContent") { + t.Fatalf("path = %s, expected generateContent endpoint", r.URL.Path) + } + if got := r.Header.Get("x-goog-api-key"); got != "test-key" { + t.Fatalf("x-goog-api-key = %q, want %q", got, "test-key") + } + if err := json.NewDecoder(r.Body).Decode(&capturedBody); err != nil { + t.Fatalf("decode request body: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{"text": "hidden", "thought": true}, + map[string]any{"text": "visible"}, + map[string]any{ + "functionCall": map[string]any{ + "id": "call_1", + "name": "search", + "args": map[string]any{"q": "hi"}, + }, + "thoughtSignature": "sig-1", + }, + }, + }, + "finishReason": "STOP", + }, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": 2, + "candidatesTokenCount": 3, + "totalTokenCount": 5, + }, + }) + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "picoclaw-test", 0, nil, nil) + resp, err := provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-3-flash-preview", + map[string]any{"thinking_level": "high"}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "visible" { + t.Fatalf("Content = %q, want %q", resp.Content, "visible") + } + if resp.ReasoningContent != "hidden" { + t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden") + } + if resp.FinishReason != "tool_calls" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if resp.Usage == nil || resp.Usage.TotalTokens != 5 { + t.Fatalf("Usage = %#v, expected total tokens = 5", resp.Usage) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].ID != "call_1" { + t.Fatalf("ToolCall ID = %q, want %q", resp.ToolCalls[0].ID, "call_1") + } + if resp.ToolCalls[0].Name != "search" { + t.Fatalf("ToolCall Name = %q, want %q", resp.ToolCalls[0].Name, "search") + } + if resp.ToolCalls[0].ThoughtSignature != "sig-1" { + t.Fatalf("ToolCall ThoughtSignature = %q, want %q", resp.ToolCalls[0].ThoughtSignature, "sig-1") + } + if resp.ToolCalls[0].Function == nil || !strings.Contains(resp.ToolCalls[0].Function.Arguments, `"q":"hi"`) { + t.Fatalf("ToolCall Function arguments = %#v, want q=hi", resp.ToolCalls[0].Function) + } + + generationConfig, ok := capturedBody["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("request missing generationConfig: %#v", capturedBody) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("request missing thinkingConfig: %#v", generationConfig) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts { + t.Fatalf("thinkingConfig.includeThoughts = %#v, want true", thinkingConfig["includeThoughts"]) + } + if got := thinkingConfig["thinkingLevel"]; got != "high" { + t.Fatalf("thinkingConfig.thinkingLevel = %#v, want %q", got, "high") + } +} + +func TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, ":streamGenerateContent") { + t.Fatalf("path = %s, expected streamGenerateContent endpoint", r.URL.Path) + } + if got := r.URL.Query().Get("alt"); got != "sse" { + t.Fatalf("alt query = %q, want %q", got, "sse") + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("response writer is not flushable") + } + + chunks := []map[string]any{ + { + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{"text": "think ", "thought": true}, + map[string]any{"text": "Hello "}, + }, + }, + }}, + }, + { + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{"text": "World"}, + map[string]any{ + "functionCall": map[string]any{ + "id": "call_stream", + "name": "search", + "args": map[string]any{"q": "stream"}, + }, + }, + }, + }, + "finishReason": "STOP", + }}, + "usageMetadata": map[string]any{ + "promptTokenCount": 1, + "candidatesTokenCount": 2, + "totalTokenCount": 3, + }, + }, + } + + for _, chunk := range chunks { + raw, err := json.Marshal(chunk) + if err != nil { + t.Fatalf("marshal chunk: %v", err) + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil { + t.Fatalf("write chunk: %v", err) + } + flusher.Flush() + } + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil) + updates := make([]string, 0) + resp, err := provider.ChatStream( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + func(accumulated string) { + updates = append(updates, accumulated) + }, + ) + if err != nil { + t.Fatalf("ChatStream() error = %v", err) + } + if resp.Content != "Hello World" { + t.Fatalf("Content = %q, want %q", resp.Content, "Hello World") + } + if resp.ReasoningContent != "think " { + t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "think ") + } + if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].ID != "call_stream" { + t.Fatalf("ToolCalls = %#v, want single call_stream", resp.ToolCalls) + } + if resp.FinishReason != "tool_calls" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if resp.Usage == nil || resp.Usage.TotalTokens != 3 { + t.Fatalf("Usage = %#v, expected total tokens = 3", resp.Usage) + } + if len(updates) < 2 || updates[len(updates)-1] != "Hello World" { + t.Fatalf("stream updates = %#v, expected final accumulated text", updates) + } +} + +func TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + + body := provider.buildRequestBody( + []Message{{ + Role: "user", + Content: "analyze attachments", + Media: []string{ + "data:application/pdf;base64,UEZERGF0YQ==", + "data:image/png;base64,aW1hZ2VEYXRh", + }, + }}, + nil, + "gemini-3-flash-preview", + map[string]any{ + "thinking_level": "low", + "max_tokens": 128, + "temperature": 0.2, + }, + ) + + contents, ok := body["contents"].([]geminiContent) + if !ok || len(contents) != 1 { + t.Fatalf("contents = %#v, want one gemini content", body["contents"]) + } + parts := contents[0].Parts + mimeSet := map[string]bool{} + for _, part := range parts { + if part.InlineData != nil { + mimeSet[part.InlineData.MIMEType] = true + } + } + if !mimeSet["application/pdf"] { + t.Fatalf("inline media missing application/pdf: %#v", parts) + } + if !mimeSet["image/png"] { + t.Fatalf("inline media missing image/png: %#v", parts) + } + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + if got := generationConfig["maxOutputTokens"]; got != 128 { + t.Fatalf("maxOutputTokens = %#v, want 128", got) + } + if got := generationConfig["temperature"]; got != 0.2 { + t.Fatalf("temperature = %#v, want 0.2", got) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"]) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts { + t.Fatalf("includeThoughts = %#v, want true", thinkingConfig["includeThoughts"]) + } + if got := thinkingConfig["thinkingLevel"]; got != "low" { + t.Fatalf("thinkingLevel = %#v, want %q", got, "low") + } +} + +func TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + map[string]any{"thinking_level": "medium"}, + ) + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"]) + } + if got := thinkingConfig["thinkingBudget"]; got != 4096 { + t.Fatalf("thinkingBudget = %#v, want 4096", got) + } + if _, hasLevel := thinkingConfig["thinkingLevel"]; hasLevel { + t.Fatalf("thinkingLevel should not be set for Gemini 2.5: %#v", thinkingConfig) + } +} + +func TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.0-flash-exp", + map[string]any{"thinking_level": "high"}, + ) + + if _, ok := body["generationConfig"]; ok { + t.Fatalf("generationConfig should be omitted for Gemini 2.0 when only thinking_level is set: %#v", body) + } +} + +func TestGeminiProvider_BuildRequestBody_PreservesToolResponseMedia(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{ + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "load_image", + Arguments: map[string]any{"path": "demo.png"}, + }}, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: "tool result", + Media: []string{ + "data:image/png;base64,aW1hZ2VEYXRh", + "data:application/pdf;base64,UEZERGF0YQ==", + }, + }, + }, + nil, + "gemini-3-flash-preview", + nil, + ) + + contents, ok := body["contents"].([]geminiContent) + if !ok || len(contents) != 2 { + t.Fatalf("contents = %#v, want two content entries", body["contents"]) + } + parts := contents[1].Parts + if len(parts) != 1 || parts[0].FunctionResponse == nil { + t.Fatalf("tool response part = %#v, want functionResponse", parts) + } + response := parts[0].FunctionResponse + if response.Name != "load_image" { + t.Fatalf("functionResponse.Name = %q, want %q", response.Name, "load_image") + } + if response.Response["result"] != "tool result" { + t.Fatalf("functionResponse.Response = %#v, want result=tool result", response.Response) + } + if len(response.Parts) != 2 { + t.Fatalf("functionResponse.Parts len = %d, want 2", len(response.Parts)) + } +} + +func TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-token" { + t.Fatalf("Authorization = %q, want %q", got, "Bearer test-token") + } + if got := r.Header.Get("x-goog-api-key"); got != "" { + t.Fatalf("x-goog-api-key = %q, want empty", got) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{ + "parts": []any{map[string]any{"text": "ok"}}, + }, + "finishReason": "STOP", + }, + }, + }) + })) + defer server.Close() + + provider := NewGeminiProvider( + "", + server.URL, + "", + "", + 0, + nil, + map[string]string{"Authorization": "Bearer test-token"}, + ) + + resp, err := provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "ok" { + t.Fatalf("Content = %q, want %q", resp.Content, "ok") + } +} + +func TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("x-goog-api-key"); got != "" { + t.Fatalf("x-goog-api-key = %q, want empty", got) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{"parts": []any{map[string]any{"text": "ok"}}}, + "finishReason": "STOP", + }, + }, + }) + })) + defer server.Close() + + provider := NewGeminiProvider("", server.URL, "", "", 0, nil, nil) + resp, err := provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "ok" { + t.Fatalf("Content = %q, want %q", resp.Content, "ok") + } +} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index d25a0fce4..98a70cfd2 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "maps" "net/http" "net/url" "strings" @@ -181,9 +182,7 @@ func (p *Provider) buildRequestBody( // Merge extra body fields configured per-provider/model. // These are injected last so they take precedence over defaults. - for k, v := range p.extraBody { - requestBody[k] = v - } + maps.Copy(requestBody, p.extraBody) return requestBody } diff --git a/web/backend/api/session.go b/web/backend/api/session.go index ae580d9aa..9bb6055e2 100644 --- a/web/backend/api/session.go +++ b/web/backend/api/session.go @@ -281,6 +281,12 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen } case "assistant": + // Reasoning-only assistant messages are transient display artifacts and + // should not be restored from session history. + if assistantMessageTransientThought(msg) { + continue + } + toolSummaryMessages := visibleAssistantToolSummaryMessages(msg.ToolCalls, toolFeedbackMaxArgsLength) if len(toolSummaryMessages) > 0 { transcript = append(transcript, toolSummaryMessages...) @@ -309,6 +315,13 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen return transcript } +func assistantMessageTransientThought(msg providers.Message) bool { + return strings.TrimSpace(msg.Content) == "" && + strings.TrimSpace(msg.ReasoningContent) != "" && + len(msg.ToolCalls) == 0 && + len(msg.Media) == 0 +} + func assistantMessageInternalOnly(msg providers.Message) bool { return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText } diff --git a/web/backend/api/session_test.go b/web/backend/api/session_test.go index 5d7620362..599921bfe 100644 --- a/web/backend/api/session_test.go +++ b/web/backend/api/session_test.go @@ -218,6 +218,59 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) { } } +func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "detail-transient-thought" + for _, msg := range []providers.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", ReasoningContent: "internal chain of thought"}, + {Role: "assistant", Content: "final visible answer"}, + } { + if err := store.AddFullMessage(nil, sessionKey, msg); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-transient-thought", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + Messages []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"messages"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(resp.Messages) != 2 { + t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages)) + } + if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "hello" { + t.Fatalf("first message = %#v, want user/hello", resp.Messages[0]) + } + if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "final visible answer" { + t.Fatalf("second message = %#v, want assistant/final visible answer", resp.Messages[1]) + } +} + func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup()