diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index ac230aa86..a856c0fca 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -105,6 +105,8 @@ const (
 	toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
 	handledToolResponseSummary = "Requested output delivered via tool attachment."
 	sessionKeyAgentPrefix = "agent:"
+	metadataKeyMessageKind = "message_kind"
+	messageKindThought = "thought"
 	metadataKeyAccountID = "account_id"
 	metadataKeyGuildID = "guild_id"
 	metadataKeyTeamID = "team_id"
@@ -1622,6 +1624,41 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
 	return ""
 }
 
+// publishPicoReasoning sends reasoningContent to the pico chat chatID as an
+// outbound message whose metadata marks it as a thought. It is best-effort:
+// empty arguments or a finished ctx make it a no-op, and publish errors are
+// logged (at debug level for timeout, cancellation or a closed bus) and dropped.
+func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, chatID string) {
+	if reasoningContent == "" || chatID == "" || ctx.Err() != nil {
+		return
+	}
+
+	// Bound the publish attempt to five seconds.
+	pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
+	defer pubCancel()
+
+	thought := bus.OutboundMessage{
+		Channel:  "pico",
+		ChatID:   chatID,
+		Content:  reasoningContent,
+		Metadata: map[string]string{metadataKeyMessageKind: messageKindThought},
+	}
+	err := al.bus.PublishOutbound(pubCtx, thought)
+	if err == nil {
+		return
+	}
+
+	fields := map[string]any{"channel": "pico", "error": err.Error()}
+	switch {
+	case errors.Is(err, context.DeadlineExceeded),
+		errors.Is(err, context.Canceled),
+		errors.Is(err, bus.ErrBusClosed):
+		logger.DebugCF("agent", "Pico reasoning publish skipped (timeout/cancel)", fields)
+	default:
+		logger.WarnCF("agent", "Failed to publish pico reasoning (best-effort)", fields)
+	}
+}
+
 func (al *AgentLoop) handleReasoning(
 	ctx context.Context,
 	reasoningContent, channelName, channelID string,
@@ -2223,12 +2260,16 @@ turnLoop:
 		if reasoningContent == "" {
 			reasoningContent = response.ReasoningContent
 		}
-		go al.handleReasoning(
-			turnCtx,
-			reasoningContent,
-			ts.channel,
al.targetReasoningChannelID(ts.channel), - ) + if ts.channel == "pico" { + go al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID) + } else { + go al.handleReasoning( + turnCtx, + reasoningContent, + ts.channel, + al.targetReasoningChannelID(ts.channel), + ) + } al.emitEvent( EventKindLLMResponse, ts.eventMeta("runTurn", "turn.llm.response"), @@ -2277,7 +2318,7 @@ turnLoop: if len(response.ToolCalls) == 0 || gracefulTerminal { responseContent := response.Content - if responseContent == "" && response.ReasoningContent != "" { + if responseContent == "" && response.ReasoningContent != "" && ts.channel != "pico" { responseContent = response.ReasoningContent } if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index a67c8d040..7fe5836b3 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -1921,7 +1921,7 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) { }, { ModelName: "gemma-fallback", - Model: "gemini/gemma-3-27b-it", + Model: "openrouter/gemma-3-27b-it", APIBase: fallbackServer.URL, APIKeys: config.SimpleSecureStrings("fallback-key"), Workspace: workspace, @@ -2660,6 +2660,62 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T } } +func TestProcessMessage_PicoPublishesReasoningAsThoughtMessage(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &reasoningContentProvider{ + response: "final answer", + reasoningContent: "thinking trace", + } + al := NewAgentLoop(cfg, msgBus, provider) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "pico", + SenderID: "user1", + ChatID: "pico:test-session", + Content: "hello", + }) + if 
err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "final answer" { + t.Fatalf("processMessage() response = %q, want %q", response, "final answer") + } + + var thoughtMsg *bus.OutboundMessage + deadline := time.After(3 * time.Second) + + for thoughtMsg == nil { + select { + case outbound := <-msgBus.OutboundChan(): + msg := outbound + if msg.Content == "thinking trace" { + thoughtMsg = &msg + } + case <-deadline: + t.Fatal("expected thought outbound message for pico") + } + } + + if thoughtMsg.Channel != "pico" || thoughtMsg.ChatID != "pico:test-session" { + t.Fatalf("thought message route = %s/%s, want pico/pico:test-session", thoughtMsg.Channel, thoughtMsg.ChatID) + } + if thoughtMsg.Metadata[metadataKeyMessageKind] != messageKindThought { + t.Fatalf("thought metadata kind = %q, want %q", thoughtMsg.Metadata[metadataKeyMessageKind], messageKindThought) + } +} + func TestProcessHeartbeat_DoesNotPublishToolFeedback(t *testing.T) { tmpDir := t.TempDir() heartbeatFile := filepath.Join(tmpDir, "heartbeat-task.txt") diff --git a/pkg/channels/pico/client.go b/pkg/channels/pico/client.go index b4bfd09e5..bf3e38cf4 100644 --- a/pkg/channels/pico/client.go +++ b/pkg/channels/pico/client.go @@ -242,7 +242,11 @@ func (c *PicoClientChannel) handleInbound(pc *picoConn, msg PicoMessage) { } func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) { - content, _ := msg.Payload["content"].(string) + if isThoughtPayload(msg.Payload) { + return + } + + content, _ := msg.Payload[PayloadKeyContent].(string) if strings.TrimSpace(content) == "" { return } @@ -285,7 +289,7 @@ func (c *PicoClientChannel) Send(ctx context.Context, msg bus.OutboundMessage) ( } outMsg := newMessage(TypeMessageSend, map[string]any{ - "content": msg.Content, + PayloadKeyContent: msg.Content, }) outMsg.SessionID = strings.TrimPrefix(msg.ChatID, "pico_client:") return nil, pc.writeJSON(outMsg) diff --git a/pkg/channels/pico/client_test.go 
b/pkg/channels/pico/client_test.go index b40606647..732589432 100644 --- a/pkg/channels/pico/client_test.go +++ b/pkg/channels/pico/client_test.go @@ -316,3 +316,67 @@ func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) { t.Fatal("timed out waiting for inbound media message") } } + +func TestIsThoughtPayload(t *testing.T) { + tests := []struct { + name string + payload map[string]any + want bool + }{ + { + name: "explicit thought bool", + payload: map[string]any{PayloadKeyThought: true}, + want: true, + }, + { + name: "thought false", + payload: map[string]any{PayloadKeyThought: false}, + want: false, + }, + { + name: "thought string ignored", + payload: map[string]any{PayloadKeyThought: "true"}, + want: false, + }, + { + name: "default normal", + payload: map[string]any{PayloadKeyContent: "hello"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isThoughtPayload(tt.payload); got != tt.want { + t.Fatalf("isThoughtPayload() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewPicoClientChannel(config.PicoClientConfig{ + URL: "ws://localhost:8080/ws", + }, mb) + if err != nil { + t.Fatalf("NewPicoClientChannel() error = %v", err) + } + + ch.ctx = context.Background() + pc := &picoConn{sessionID: "sess-thought"} + + ch.handleServerMessage(pc, PicoMessage{ + Type: TypeMessageCreate, + Payload: map[string]any{ + PayloadKeyContent: "internal reasoning", + PayloadKeyThought: true, + }, + }) + + select { + case msg := <-mb.InboundChan(): + t.Fatalf("expected no inbound publish for thought payload, got %+v", msg) + case <-time.After(150 * time.Millisecond): + } +} diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index e22da1ba1..6525c2d4a 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -39,6 +39,13 @@ var allowedInlineImageMIMETypes = 
map[string]struct{}{
 	"image/bmp": {},
 }
 
+// outboundMessageIsThought reports whether metadata tags the message as a thought.
+func outboundMessageIsThought(metadata map[string]string) bool {
+	// A nil or empty map yields "" here, which never equals the thought marker.
+	kind := strings.TrimSpace(metadata["message_kind"])
+	return strings.EqualFold(kind, MessageKindThought)
+}
+
 // writeJSON sends a JSON message to the connection with write locking.
 func (pc *picoConn) writeJSON(v any) error {
 	if pc.closed.Load() {
@@ -247,9 +254,11 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
 	if !c.IsRunning() {
 		return nil, channels.ErrNotRunning
 	}
+	isThought := outboundMessageIsThought(msg.Metadata)
 
 	outMsg := newMessage(TypeMessageCreate, map[string]any{
-		"content": msg.Content,
+		PayloadKeyContent: msg.Content,
+		PayloadKeyThought: isThought,
 	})
 
 	return nil, c.broadcastToSession(msg.ChatID, outMsg)
@@ -288,8 +297,9 @@ func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (strin
 	msgID := uuid.New().String()
 
 	outMsg := newMessage(TypeMessageCreate, map[string]any{
-		"content": text,
-		"message_id": msgID,
+		PayloadKeyContent: text,
+		PayloadKeyThought: false,
+		"message_id": msgID,
 	})
 
 	if err := c.broadcastToSession(chatID, outMsg); err != nil {
diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go
index 3f8ba8643..ecdc2d140 100644
--- a/pkg/channels/pico/protocol.go
+++ b/pkg/channels/pico/protocol.go
@@ -19,6 +19,11 @@ const (
 	TypePong = "pong"
 
 	PicoTokenPrefix = "pico-"
+
+	PayloadKeyContent = "content"
+	PayloadKeyThought = "thought"
+
+	MessageKindThought = "thought"
 )
 
 // PicoMessage is the wire format for all Pico Protocol messages.
@@ -39,6 +44,11 @@ func newMessage(msgType string, payload map[string]any) PicoMessage { } } +func isThoughtPayload(payload map[string]any) bool { + thought, _ := payload[PayloadKeyThought].(bool) + return thought +} + func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage { payload := map[string]any{ "code": code, diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go index 8a1890212..b5ab847d5 100644 --- a/pkg/providers/antigravity_provider.go +++ b/pkg/providers/antigravity_provider.go @@ -389,6 +389,7 @@ type antigravityJSONResponse struct { Content struct { Parts []struct { Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` ThoughtSignature string `json:"thoughtSignature,omitempty"` ThoughtSignatureSnake string `json:"thought_signature,omitempty"` FunctionCall *antigravityFunctionCall `json:"functionCall,omitempty"` @@ -406,6 +407,7 @@ type antigravityJSONResponse struct { func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error) { var contentParts []string + var reasoningParts []string var toolCalls []ToolCall var usage *UsageInfo var finishReason string @@ -433,7 +435,11 @@ func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error for _, candidate := range resp.Candidates { for _, part := range candidate.Content.Parts { if part.Text != "" { - contentParts = append(contentParts, part.Text) + if part.Thought { + reasoningParts = append(reasoningParts, part.Text) + } else { + contentParts = append(contentParts, part.Text) + } } if part.FunctionCall != nil { argumentsJSON, _ := json.Marshal(part.FunctionCall.Args) @@ -475,10 +481,11 @@ func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error } return &LLMResponse{ - Content: strings.Join(contentParts, ""), - ToolCalls: toolCalls, - FinishReason: mappedFinish, - Usage: usage, + Content: strings.Join(contentParts, ""), + ReasoningContent: 
strings.Join(reasoningParts, ""), + ToolCalls: toolCalls, + FinishReason: mappedFinish, + Usage: usage, }, nil } diff --git a/pkg/providers/antigravity_provider_test.go b/pkg/providers/antigravity_provider_test.go index 238765321..9155e2d56 100644 --- a/pkg/providers/antigravity_provider_test.go +++ b/pkg/providers/antigravity_provider_test.go @@ -54,3 +54,27 @@ func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) { t.Fatalf("expected inferred tool name search_docs, got %q", got) } } + +func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) { + p := &AntigravityProvider{} + body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" + + "data: [DONE]\n" + + resp, err := p.parseSSEResponse(body) + if err != nil { + t.Fatalf("parseSSEResponse() error = %v", err) + } + + if resp.Content != "visible answer" { + t.Fatalf("Content = %q, want %q", resp.Content, "visible answer") + } + if resp.ReasoningContent != "hidden reasoning" { + t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden reasoning") + } + if resp.FinishReason != "stop" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage == nil || resp.Usage.TotalTokens != 216 { + t.Fatalf("Usage.TotalTokens = %v, want %d", resp.Usage, 216) + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index f13dc646c..ab68b326a 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -114,7 +114,7 @@ func ResolveAPIBase(cfg *config.ModelConfig) string { // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. 
-// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini), +// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq), // Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims. // See the switch on protocol in this function for the authoritative list. // Returns the provider, the model ID (without protocol prefix), and any error. @@ -218,7 +218,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err } return provider, modelID, nil - case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "gemini", "nvidia", "venice", + case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl", "qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita", @@ -242,6 +242,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.CustomHeaders, ), modelID, nil + case "gemini": + if cfg.APIKey() == "" && cfg.APIBase == "" { + return nil, "", fmt.Errorf("api_key or api_base is required for gemini protocol (model: %s)", cfg.Model) + } + apiBase := cfg.APIBase + if apiBase == "" { + apiBase = getDefaultAPIBase(protocol) + } + return NewGeminiProvider( + cfg.APIKey(), + apiBase, + cfg.Proxy, + userAgent, + cfg.RequestTimeout, + cfg.ExtraBody, + cfg.CustomHeaders, + ), modelID, nil + case "minimax": // Minimax requires reasoning_split: true in the request body if cfg.APIKey() == "" && cfg.APIBase == "" { diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index c362463ae..20cdd8a30 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -434,6 +434,62 @@ func TestCreateProviderFromConfig_Antigravity(t *testing.T) { } } 
+func TestCreateProviderFromConfig_Gemini(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-gemini", + Model: "gemini/gemini-2.5-flash", + } + cfg.SetAPIKey("test-key") + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gemini-2.5-flash" { + t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash") + } + if _, ok := provider.(*GeminiProvider); !ok { + t.Fatalf("expected *GeminiProvider, got %T", provider) + } +} + +func TestCreateProviderFromConfig_GeminiMissingAPIKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-gemini-no-key", + Model: "gemini/gemini-2.5-flash", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing gemini API key") + } +} + +func TestCreateProviderFromConfig_GeminiCustomAPIBaseWithoutKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-gemini-custom-base", + Model: "gemini/gemini-2.5-flash", + APIBase: "https://proxy.example.com/v1beta", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gemini-2.5-flash" { + t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash") + } + if _, ok := provider.(*GeminiProvider); !ok { + t.Fatalf("expected *GeminiProvider, got %T", provider) + } +} + func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) { cfg := &config.ModelConfig{ ModelName: "test-claude-cli", diff --git a/pkg/providers/gemini_provider.go b/pkg/providers/gemini_provider.go new file mode 100644 index 000000000..561387534 --- /dev/null +++ b/pkg/providers/gemini_provider.go @@ -0,0 +1,796 @@ +package providers + +import ( 
+ "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/common" +) + +const ( + geminiDefaultAPIBase = "https://generativelanguage.googleapis.com/v1beta" + geminiDefaultModel = "gemini-2.0-flash" +) + +type GeminiProvider struct { + apiKey string + apiBase string + httpClient *http.Client + extraBody map[string]any + customHeaders map[string]string + userAgent string +} + +func NewGeminiProvider( + apiKey string, + apiBase string, + proxy string, + userAgent string, + requestTimeoutSeconds int, + extraBody map[string]any, + customHeaders map[string]string, +) *GeminiProvider { + if strings.TrimSpace(apiBase) == "" { + apiBase = geminiDefaultAPIBase + } + client := common.NewHTTPClient(proxy) + if requestTimeoutSeconds > 0 { + client.Timeout = time.Duration(requestTimeoutSeconds) * time.Second + } + + return &GeminiProvider{ + apiKey: strings.TrimSpace(apiKey), + apiBase: strings.TrimRight(strings.TrimSpace(apiBase), "/"), + httpClient: client, + extraBody: cloneAnyMap(extraBody), + customHeaders: cloneStringMap(customHeaders), + userAgent: strings.TrimSpace(userAgent), + } +} + +func (p *GeminiProvider) GetDefaultModel() string { + return geminiDefaultModel +} + +func (p *GeminiProvider) SupportsThinking() bool { + return true +} + +func (p *GeminiProvider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + model = normalizeGeminiModel(model) + requestBody := p.buildRequestBody(messages, tools, model, options) + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("%s/models/%s:generateContent", p.apiBase, model) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData)) 
+ if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + p.applyHeaders(req) + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, common.HandleErrorResponse(resp, p.apiBase) + } + + var apiResp geminiGenerateContentResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return parseGeminiResponse(&apiResp), nil +} + +func (p *GeminiProvider) ChatStream( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + onChunk func(accumulated string), +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + model = normalizeGeminiModel(model) + requestBody := p.buildRequestBody(messages, tools, model, options) + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("%s/models/%s:streamGenerateContent?alt=sse", p.apiBase, model) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + p.applyHeaders(req) + req.Header.Set("Accept", "text/event-stream") + + // Streaming should not use a whole-request timeout; context cancellation is the guard. 
+ streamClient := &http.Client{Transport: p.httpClient.Transport} + resp, err := streamClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, common.HandleErrorResponse(resp, p.apiBase) + } + + return parseGeminiStreamResponse(ctx, resp.Body, onChunk) +} + +func (p *GeminiProvider) applyHeaders(req *http.Request) { + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("X-Goog-Api-Key", p.apiKey) + } + if p.userAgent != "" { + req.Header.Set("User-Agent", p.userAgent) + } + for k, v := range p.customHeaders { + if strings.TrimSpace(k) == "" { + continue + } + req.Header.Set(k, v) + } +} + +func (p *GeminiProvider) buildRequestBody( + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) map[string]any { + contents := make([]geminiContent, 0, len(messages)) + toolCallNames := make(map[string]string) + systemPrompts := make([]string, 0, 1) + + for _, msg := range messages { + switch msg.Role { + case "system": + if strings.TrimSpace(msg.Content) != "" { + systemPrompts = append(systemPrompts, msg.Content) + } + + case "user": + if msg.ToolCallID != "" { + toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + contents = append(contents, geminiContent{ + Role: "user", + Parts: []geminiPart{{ + FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media), + }}, + }) + continue + } + + parts := make([]geminiPart, 0, 1+len(msg.Media)) + if strings.TrimSpace(msg.Content) != "" { + parts = append(parts, geminiPart{Text: msg.Content}) + } + parts = append(parts, buildInlineMediaParts(msg.Media)...) 
+ if len(parts) > 0 { + contents = append(contents, geminiContent{Role: "user", Parts: parts}) + } + + case "assistant": + content := geminiContent{Role: "model"} + if strings.TrimSpace(msg.Content) != "" { + content.Parts = append(content.Parts, geminiPart{Text: msg.Content}) + } + for _, tc := range msg.ToolCalls { + toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc) + if toolName == "" { + continue + } + if tc.ID != "" { + toolCallNames[tc.ID] = toolName + } + part := geminiPart{ + FunctionCall: &geminiFunctionCall{ + Name: toolName, + Args: toolArgs, + ID: tc.ID, + }, + } + if thoughtSignature != "" { + part.ThoughtSignature = thoughtSignature + } + content.Parts = append(content.Parts, part) + } + if len(content.Parts) > 0 { + contents = append(contents, content) + } + + case "tool": + toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + contents = append(contents, geminiContent{ + Role: "user", + Parts: []geminiPart{{ + FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media), + }}, + }) + } + } + + body := map[string]any{ + "contents": contents, + } + if len(systemPrompts) > 0 { + systemParts := make([]geminiPart, 0, len(systemPrompts)) + for _, prompt := range systemPrompts { + systemParts = append(systemParts, geminiPart{Text: prompt}) + } + body["systemInstruction"] = &geminiContent{Parts: systemParts} + } + + if len(tools) > 0 { + funcDecls := make([]geminiFunctionDeclaration, 0, len(tools)) + for _, t := range tools { + if t.Type != "function" { + continue + } + funcDecls = append(funcDecls, geminiFunctionDeclaration{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: sanitizeSchemaForGemini(t.Function.Parameters), + }) + } + if len(funcDecls) > 0 { + body["tools"] = []geminiTool{{FunctionDeclarations: funcDecls}} + } + } + + generationConfig := make(map[string]any) + if val, ok := options["max_tokens"]; ok { + if maxTokens, ok := val.(int); ok && 
maxTokens > 0 { + generationConfig["maxOutputTokens"] = maxTokens + } else if maxTokens, ok := val.(float64); ok && maxTokens > 0 { + generationConfig["maxOutputTokens"] = int(maxTokens) + } + } + if temp, ok := options["temperature"].(float64); ok { + generationConfig["temperature"] = temp + } + + if thinkingConfig := buildGeminiThinkingConfig(model, options); len(thinkingConfig) > 0 { + generationConfig["thinkingConfig"] = thinkingConfig + } + + if len(generationConfig) > 0 { + body["generationConfig"] = generationConfig + } + + for k, v := range p.extraBody { + body[k] = v + } + + return body +} + +func normalizeGeminiModel(model string) string { + model = strings.TrimSpace(model) + model = strings.TrimPrefix(model, "models/") + if strings.Contains(model, "/") { + _, modelID := ExtractProtocol(model) + if modelID != "" { + return modelID + } + } + if model == "" { + return geminiDefaultModel + } + return model +} + +func mapGeminiThinkingLevel(level string) string { + switch strings.ToLower(strings.TrimSpace(level)) { + case "minimal", "off": + return "minimal" + case "low": + return "low" + case "medium": + return "medium" + case "high", "xhigh", "adaptive": + return "high" + default: + return "" + } +} + +func buildGeminiThinkingConfig(model string, options map[string]any) map[string]any { + if !geminiModelSupportsThinkingConfig(model) { + return nil + } + + config := map[string]any{} + rawLevel, _ := options["thinking_level"].(string) + rawLevel = strings.ToLower(strings.TrimSpace(rawLevel)) + if rawLevel == "" { + // Align with agent-level default: unset means ThinkingOff. + rawLevel = "off" + } + + includeThoughts := rawLevel != "off" && rawLevel != "minimal" + config["includeThoughts"] = includeThoughts + + if isGemini25Model(model) { + if isGemini25ProModel(model) && (rawLevel == "off" || rawLevel == "minimal") { + // Gemini 2.5 Pro cannot disable thinking; keep model-default thinking. 
+			return config
+		}
+		if budget, ok := mapGeminiThinkingBudget(rawLevel); ok {
+			config["thinkingBudget"] = budget
+		}
+		return config
+	}
+
+	if isGemini3ProModel(model) && (rawLevel == "off" || rawLevel == "minimal") {
+		// Gemini 3.x Pro does not support minimal thinking level.
+		return config
+	}
+
+	if thinkingLevel := mapGeminiThinkingLevel(rawLevel); thinkingLevel != "" {
+		config["thinkingLevel"] = thinkingLevel
+	}
+	return config
+}
+
+// geminiModelSupportsThinkingConfig reports whether model carries a Gemini 3 or Gemini 2.5 marker.
+func geminiModelSupportsThinkingConfig(model string) bool {
+	name := strings.ToLower(strings.TrimSpace(model))
+	return isGemini25Model(name) || strings.Contains(name, "gemini-3")
+}
+
+// isGemini25Model reports whether model contains "gemini-2.5" or "gemini-25", ignoring case.
+func isGemini25Model(model string) bool {
+	name := strings.ToLower(strings.TrimSpace(model))
+	return strings.Contains(name, "gemini-25") || strings.Contains(name, "gemini-2.5")
+}
+
+// isGemini25ProModel reports whether model is a Gemini 2.5 name that also contains "pro".
+func isGemini25ProModel(model string) bool {
+	name := strings.ToLower(strings.TrimSpace(model))
+	return strings.Contains(name, "pro") && isGemini25Model(name)
+}
+
+// isGemini3ProModel reports whether model contains both "gemini-3" and "pro", ignoring case.
+func isGemini3ProModel(model string) bool {
+	name := strings.ToLower(strings.TrimSpace(model))
+	return strings.Contains(name, "pro") && strings.Contains(name, "gemini-3")
+}
+
+// mapGeminiThinkingBudget converts a thinking level name into a numeric budget.
+// The boolean result is false when the level is empty or not recognized.
+func mapGeminiThinkingBudget(level string) (int, bool) {
+	switch strings.ToLower(strings.TrimSpace(level)) {
+	case "adaptive":
+		return -1, true
+	case "minimal", "off":
+		return 0, true
+	case "low":
+		return 1024, true
+	case "medium":
+		return 4096, true
+	case "high":
+		return 8192, true
+	case "xhigh":
+		return 16384, true
+	default:
+		// Also reached for the empty string.
+		return 0, false
+	}
+}
+
+func parseGeminiResponse(resp *geminiGenerateContentResponse) *LLMResponse {
+	contentParts := make([]string, 0)
+	reasoningParts := make([]string, 0)
+	toolCalls := make([]ToolCall, 0)
+	finishReason := ""
+
+	for _, candidate := range resp.Candidates {
+		for _, part := range candidate.Content.Parts {
+
if part.Text != "" { + if part.Thought { + reasoningParts = append(reasoningParts, part.Text) + } else { + contentParts = append(contentParts, part.Text) + } + } + if part.FunctionCall != nil { + toolCalls = append(toolCalls, buildGeminiToolCall(part)) + } + } + if candidate.FinishReason != "" { + finishReason = candidate.FinishReason + } + } + + var usage *UsageInfo + if resp.UsageMetadata.TotalTokenCount > 0 { + usage = &UsageInfo{ + PromptTokens: resp.UsageMetadata.PromptTokenCount, + CompletionTokens: resp.UsageMetadata.CandidatesTokenCount, + TotalTokens: resp.UsageMetadata.TotalTokenCount, + } + } + + return &LLMResponse{ + Content: strings.Join(contentParts, ""), + ReasoningContent: strings.Join(reasoningParts, ""), + ToolCalls: toolCalls, + FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)), + Usage: usage, + } +} + +func parseGeminiStreamResponse( + ctx context.Context, + reader io.Reader, + onChunk func(accumulated string), +) (*LLMResponse, error) { + var contentBuilder strings.Builder + var reasoningBuilder strings.Builder + var finishReason string + var usage *UsageInfo + + toolCallsByID := make(map[string]ToolCall) + toolCallOrder := make([]string, 0) + fallbackIndex := 0 + + scanner := bufio.NewScanner(reader) + scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024) + for scanner.Scan() { + if err := ctx.Err(); err != nil { + return nil, err + } + + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if data == "" { + continue + } + if data == "[DONE]" { + break + } + + var chunk geminiGenerateContentResponse + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + return nil, fmt.Errorf("invalid gemini stream chunk: %w", err) + } + + for _, candidate := range chunk.Candidates { + for _, part := range candidate.Content.Parts { + if part.Text != "" { + if part.Thought { + reasoningBuilder.WriteString(part.Text) + } else { + 
contentBuilder.WriteString(part.Text) + if onChunk != nil { + onChunk(contentBuilder.String()) + } + } + } + if part.FunctionCall != nil { + tc := buildGeminiToolCall(part) + if strings.TrimSpace(tc.Name) == "" { + continue + } + + key := strings.TrimSpace(part.FunctionCall.ID) + if key == "" { + if len(toolCallOrder) > 0 { + lastKey := toolCallOrder[len(toolCallOrder)-1] + if lastTC, exists := toolCallsByID[lastKey]; exists && lastTC.Name == tc.Name { + key = lastKey + } + } + if key == "" { + fallbackIndex++ + key = fmt.Sprintf("%s#%d", tc.Name, fallbackIndex) + } + } + + tc.ID = key + if _, exists := toolCallsByID[key]; !exists { + toolCallOrder = append(toolCallOrder, key) + } + toolCallsByID[key] = tc + } + } + if candidate.FinishReason != "" { + finishReason = candidate.FinishReason + } + } + + if chunk.UsageMetadata.TotalTokenCount > 0 { + usage = &UsageInfo{ + PromptTokens: chunk.UsageMetadata.PromptTokenCount, + CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount, + TotalTokens: chunk.UsageMetadata.TotalTokenCount, + } + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("streaming read error: %w", err) + } + + toolCalls := make([]ToolCall, 0, len(toolCallOrder)) + for _, key := range toolCallOrder { + toolCalls = append(toolCalls, toolCallsByID[key]) + } + + return &LLMResponse{ + Content: contentBuilder.String(), + ReasoningContent: reasoningBuilder.String(), + ToolCalls: toolCalls, + FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)), + Usage: usage, + }, nil +} + +func normalizeGeminiFinishReason(reason string, toolCalls int) string { + if toolCalls > 0 { + return "tool_calls" + } + + switch strings.ToUpper(strings.TrimSpace(reason)) { + case "MAX_TOKENS": + return "length" + case "", "STOP": + return "stop" + default: + return strings.ToLower(strings.TrimSpace(reason)) + } +} + +func buildGeminiToolCall(part geminiPart) ToolCall { + if part.FunctionCall == nil { + return ToolCall{} + } + + args := 
part.FunctionCall.Args + if args == nil { + args = make(map[string]any) + } + argsJSON, _ := json.Marshal(args) + thoughtSignature := extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake) + + toolCall := ToolCall{ + ID: part.FunctionCall.ID, + Name: part.FunctionCall.Name, + Arguments: args, + ThoughtSignature: thoughtSignature, + Function: &FunctionCall{ + Name: part.FunctionCall.Name, + Arguments: string(argsJSON), + ThoughtSignature: thoughtSignature, + }, + } + + if thoughtSignature != "" { + toolCall.ExtraContent = &ExtraContent{ + Google: &GoogleExtra{ThoughtSignature: thoughtSignature}, + } + } + if strings.TrimSpace(toolCall.ID) == "" { + toolCall.ID = fmt.Sprintf("call_%s_%d", toolCall.Name, time.Now().UnixNano()) + } + + return toolCall +} + +func buildInlineMediaParts(media []string) []geminiPart { + parts := make([]geminiPart, 0, len(media)) + for _, mediaURL := range media { + mimeType, data, ok := parseBase64DataURL(mediaURL) + if !ok { + continue + } + parts = append(parts, geminiPart{ + InlineData: &geminiInlineData{ + MIMEType: mimeType, + Data: data, + }, + }) + } + return parts +} + +func buildGeminiFunctionResponse( + toolName string, + toolCallID string, + result string, + media []string, +) *geminiFunctionResponse { + response := &geminiFunctionResponse{ + ID: toolCallID, + Name: toolName, + Response: map[string]any{ + "result": result, + }, + } + + if parts := buildFunctionResponseMediaParts(media); len(parts) > 0 { + response.Parts = parts + } + + return response +} + +func buildFunctionResponseMediaParts(media []string) []geminiFunctionResponsePart { + parts := make([]geminiFunctionResponsePart, 0, len(media)) + for i, mediaURL := range media { + mimeType, data, ok := parseBase64DataURL(mediaURL) + if !ok { + continue + } + parts = append(parts, geminiFunctionResponsePart{ + InlineData: &geminiInlineData{ + MIMEType: mimeType, + Data: data, + DisplayName: defaultFunctionResponseDisplayName(mimeType, i+1), + }, + }) 
+ } + return parts +} + +func defaultFunctionResponseDisplayName(mimeType string, index int) string { + suffix := "bin" + switch strings.ToLower(strings.TrimSpace(mimeType)) { + case "image/png": + suffix = "png" + case "image/jpeg": + suffix = "jpg" + case "image/webp": + suffix = "webp" + case "application/pdf": + suffix = "pdf" + case "text/plain": + suffix = "txt" + } + return fmt.Sprintf("attachment-%d.%s", index, suffix) +} + +func parseBase64DataURL(mediaURL string) (mimeType string, data string, ok bool) { + if !strings.HasPrefix(mediaURL, "data:") { + return "", "", false + } + + payload := strings.TrimPrefix(mediaURL, "data:") + header, data, found := strings.Cut(payload, ",") + if !found { + return "", "", false + } + mimeType, params, _ := strings.Cut(header, ";") + mimeType = strings.TrimSpace(mimeType) + data = strings.TrimSpace(data) + if mimeType == "" || data == "" { + return "", "", false + } + if !strings.Contains(strings.ToLower(params), "base64") { + return "", "", false + } + return mimeType, data, true +} + +func cloneAnyMap(in map[string]any) map[string]any { + if len(in) == 0 { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +type geminiGenerateContentResponse struct { + Candidates []struct { + Content struct { + Role string `json:"role"` + Parts []geminiPart `json:"parts"` + } `json:"content"` + FinishReason string `json:"finishReason"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` +} + +type geminiContent struct { + Role string `json:"role,omitempty"` + Parts []geminiPart `json:"parts"` +} + 
+type geminiPart struct { + Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + ThoughtSignatureSnake string `json:"thought_signature,omitempty"` + InlineData *geminiInlineData `json:"inlineData,omitempty"` + FunctionCall *geminiFunctionCall `json:"functionCall,omitempty"` + FunctionResponse *geminiFunctionResponse `json:"functionResponse,omitempty"` +} + +type geminiInlineData struct { + MIMEType string `json:"mimeType"` + Data string `json:"data"` + DisplayName string `json:"displayName,omitempty"` +} + +type geminiFunctionCall struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Args map[string]any `json:"args,omitempty"` +} + +type geminiFunctionResponse struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Response map[string]any `json:"response"` + Parts []geminiFunctionResponsePart `json:"parts,omitempty"` +} + +type geminiFunctionResponsePart struct { + InlineData *geminiInlineData `json:"inlineData,omitempty"` +} + +type geminiTool struct { + FunctionDeclarations []geminiFunctionDeclaration `json:"functionDeclarations"` +} + +type geminiFunctionDeclaration struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` +} diff --git a/pkg/providers/gemini_provider_test.go b/pkg/providers/gemini_provider_test.go new file mode 100644 index 000000000..a0ab748eb --- /dev/null +++ b/pkg/providers/gemini_provider_test.go @@ -0,0 +1,763 @@ +package providers + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) { + var capturedBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("method = %s, want POST", r.Method) + } + if 
!strings.Contains(r.URL.Path, ":generateContent") { + t.Fatalf("path = %s, expected generateContent endpoint", r.URL.Path) + } + if got := r.Header.Get("X-Goog-Api-Key"); got != "test-key" { + t.Fatalf("X-Goog-Api-Key = %q, want %q", got, "test-key") + } + if err := json.NewDecoder(r.Body).Decode(&capturedBody); err != nil { + t.Fatalf("decode request body: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{"text": "hidden", "thought": true}, + map[string]any{"text": "visible"}, + map[string]any{ + "functionCall": map[string]any{ + "id": "call_1", + "name": "search", + "args": map[string]any{"q": "hi"}, + }, + "thoughtSignature": "sig-1", + }, + }, + }, + "finishReason": "STOP", + }, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": 2, + "candidatesTokenCount": 3, + "totalTokenCount": 5, + }, + }) + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "picoclaw-test", 0, nil, nil) + resp, err := provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-3-flash-preview", + map[string]any{"thinking_level": "high"}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "visible" { + t.Fatalf("Content = %q, want %q", resp.Content, "visible") + } + if resp.ReasoningContent != "hidden" { + t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden") + } + if resp.FinishReason != "tool_calls" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if resp.Usage == nil || resp.Usage.TotalTokens != 5 { + t.Fatalf("Usage = %#v, expected total tokens = 5", resp.Usage) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].ID != "call_1" { + t.Fatalf("ToolCall ID = %q, 
want %q", resp.ToolCalls[0].ID, "call_1") + } + if resp.ToolCalls[0].Name != "search" { + t.Fatalf("ToolCall Name = %q, want %q", resp.ToolCalls[0].Name, "search") + } + if resp.ToolCalls[0].ThoughtSignature != "sig-1" { + t.Fatalf("ToolCall ThoughtSignature = %q, want %q", resp.ToolCalls[0].ThoughtSignature, "sig-1") + } + if resp.ToolCalls[0].Function == nil || !strings.Contains(resp.ToolCalls[0].Function.Arguments, `"q":"hi"`) { + t.Fatalf("ToolCall Function arguments = %#v, want q=hi", resp.ToolCalls[0].Function) + } + + generationConfig, ok := capturedBody["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("request missing generationConfig: %#v", capturedBody) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("request missing thinkingConfig: %#v", generationConfig) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts { + t.Fatalf("thinkingConfig.includeThoughts = %#v, want true", thinkingConfig["includeThoughts"]) + } + if got := thinkingConfig["thinkingLevel"]; got != "high" { + t.Fatalf("thinkingConfig.thinkingLevel = %#v, want %q", got, "high") + } +} + +func TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, ":streamGenerateContent") { + t.Fatalf("path = %s, expected streamGenerateContent endpoint", r.URL.Path) + } + if got := r.URL.Query().Get("alt"); got != "sse" { + t.Fatalf("alt query = %q, want %q", got, "sse") + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("response writer is not flushable") + } + + chunks := []map[string]any{ + { + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{"text": "think ", "thought": true}, + map[string]any{"text": "Hello "}, + }, + }, + }}, + }, + { + 
"candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{"text": "World"}, + map[string]any{ + "functionCall": map[string]any{ + "id": "call_stream", + "name": "search", + "args": map[string]any{"q": "stream"}, + }, + }, + }, + }, + "finishReason": "STOP", + }}, + "usageMetadata": map[string]any{ + "promptTokenCount": 1, + "candidatesTokenCount": 2, + "totalTokenCount": 3, + }, + }, + } + + for _, chunk := range chunks { + raw, err := json.Marshal(chunk) + if err != nil { + t.Fatalf("marshal chunk: %v", err) + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil { + t.Fatalf("write chunk: %v", err) + } + flusher.Flush() + } + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil) + updates := make([]string, 0) + resp, err := provider.ChatStream( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + func(accumulated string) { + updates = append(updates, accumulated) + }, + ) + if err != nil { + t.Fatalf("ChatStream() error = %v", err) + } + if resp.Content != "Hello World" { + t.Fatalf("Content = %q, want %q", resp.Content, "Hello World") + } + if resp.ReasoningContent != "think " { + t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "think ") + } + if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].ID != "call_stream" { + t.Fatalf("ToolCalls = %#v, want single call_stream", resp.ToolCalls) + } + if resp.FinishReason != "tool_calls" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if resp.Usage == nil || resp.Usage.TotalTokens != 3 { + t.Fatalf("Usage = %#v, expected total tokens = 3", resp.Usage) + } + if len(updates) < 2 || updates[len(updates)-1] != "Hello World" { + t.Fatalf("stream updates = %#v, expected final accumulated text", updates) + } +} + +func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t 
*testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("response writer is not flushable") + } + + _, _ = fmt.Fprint(w, "data: \n\n") + flusher.Flush() + + chunk := map[string]any{ + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{map[string]any{"text": "ok"}}, + }, + "finishReason": "STOP", + }}, + } + raw, err := json.Marshal(chunk) + if err != nil { + t.Fatalf("marshal chunk: %v", err) + } + _, _ = fmt.Fprintf(w, "data: %s\n\n", raw) + flusher.Flush() + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil) + resp, err := provider.ChatStream( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + nil, + ) + if err != nil { + t.Fatalf("ChatStream() error = %v", err) + } + if resp.Content != "ok" { + t.Fatalf("Content = %q, want %q", resp.Content, "ok") + } +} + +func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("response writer is not flushable") + } + + _, _ = fmt.Fprint(w, "data: {invalid-json}\n\n") + flusher.Flush() + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil) + _, err := provider.ChatStream( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + nil, + ) + if err == nil { + t.Fatal("ChatStream() expected error for invalid SSE data frame") + } + if !strings.Contains(err.Error(), "invalid gemini stream chunk") { + t.Fatalf("error = %v, want contains %q", err, "invalid gemini 
stream chunk") + } +} + +func TestGeminiProvider_BuildRequestBody_UsesCamelCaseThoughtSignatureOnly(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + + body := provider.buildRequestBody( + []Message{{ + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "search", + Arguments: map[string]any{"q": "hello"}, + Function: &FunctionCall{ + Name: "search", + Arguments: `{"q":"hello"}`, + ThoughtSignature: "sig-1", + }, + }}, + }}, + nil, + "gemini-2.5-flash", + nil, + ) + + raw, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal request body: %v", err) + } + jsonBody := string(raw) + + if !strings.Contains(jsonBody, `"thoughtSignature":"sig-1"`) { + t.Fatalf("request body = %s, expected camelCase thoughtSignature", jsonBody) + } + if strings.Contains(jsonBody, `"thought_signature"`) { + t.Fatalf("request body = %s, unexpected snake_case thought_signature", jsonBody) + } +} + +func TestGeminiProvider_ChatStreamCoalescesToolCallWithoutWireID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("response writer is not flushable") + } + + chunks := []map[string]any{ + { + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{ + "functionCall": map[string]any{ + "name": "search", + "args": map[string]any{"q": "first"}, + }, + }, + }, + }, + }}, + }, + { + "candidates": []any{map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{ + "functionCall": map[string]any{ + "name": "search", + "args": map[string]any{"q": "second"}, + }, + }, + }, + }, + "finishReason": "STOP", + }}, + }, + } + + for _, chunk := range chunks { + raw, err := json.Marshal(chunk) + if err != nil { + t.Fatalf("marshal chunk: %v", err) + } + if _, err := fmt.Fprintf(w, "data: 
%s\n\n", raw); err != nil { + t.Fatalf("write chunk: %v", err) + } + flusher.Flush() + } + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil) + resp, err := provider.ChatStream( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + nil, + ) + if err != nil { + t.Fatalf("ChatStream() error = %v", err) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls)) + } + tc := resp.ToolCalls[0] + if tc.ID != "search#1" { + t.Fatalf("ToolCall ID = %q, want %q", tc.ID, "search#1") + } + if tc.Name != "search" { + t.Fatalf("ToolCall Name = %q, want %q", tc.Name, "search") + } + if argQ, ok := tc.Arguments["q"].(string); !ok || argQ != "second" { + t.Fatalf("ToolCall Arguments = %#v, want q=second", tc.Arguments) + } + if resp.FinishReason != "tool_calls" { + t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } +} + +func TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + + body := provider.buildRequestBody( + []Message{{ + Role: "user", + Content: "analyze attachments", + Media: []string{ + "data:application/pdf;base64,UEZERGF0YQ==", + "data:image/png;base64,aW1hZ2VEYXRh", + }, + }}, + nil, + "gemini-3-flash-preview", + map[string]any{ + "thinking_level": "low", + "max_tokens": 128, + "temperature": 0.2, + }, + ) + + contents, ok := body["contents"].([]geminiContent) + if !ok || len(contents) != 1 { + t.Fatalf("contents = %#v, want one gemini content", body["contents"]) + } + parts := contents[0].Parts + mimeSet := map[string]bool{} + for _, part := range parts { + if part.InlineData != nil { + mimeSet[part.InlineData.MIMEType] = true + } + } + if !mimeSet["application/pdf"] { + t.Fatalf("inline media missing application/pdf: 
%#v", parts) + } + if !mimeSet["image/png"] { + t.Fatalf("inline media missing image/png: %#v", parts) + } + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + if got := generationConfig["maxOutputTokens"]; got != 128 { + t.Fatalf("maxOutputTokens = %#v, want 128", got) + } + if got := generationConfig["temperature"]; got != 0.2 { + t.Fatalf("temperature = %#v, want 0.2", got) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"]) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts { + t.Fatalf("includeThoughts = %#v, want true", thinkingConfig["includeThoughts"]) + } + if got := thinkingConfig["thinkingLevel"]; got != "low" { + t.Fatalf("thinkingLevel = %#v, want %q", got, "low") + } +} + +func TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + map[string]any{"thinking_level": "medium"}, + ) + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"]) + } + if got := thinkingConfig["thinkingBudget"]; got != 4096 { + t.Fatalf("thinkingBudget = %#v, want 4096", got) + } + if _, hasLevel := thinkingConfig["thinkingLevel"]; hasLevel { + t.Fatalf("thinkingLevel should not be set for Gemini 2.5: %#v", thinkingConfig) + } +} + +func TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20(t 
*testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.0-flash-exp", + map[string]any{"thinking_level": "high"}, + ) + + if _, ok := body["generationConfig"]; ok { + t.Fatalf("generationConfig should be omitted for Gemini 2.0 when only thinking_level is set: %#v", body) + } +} + +func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + ) + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"]) + } + if got := thinkingConfig["thinkingBudget"]; got != 0 { + t.Fatalf("thinkingBudget = %#v, want 0 for default/off", got) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts { + t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"]) + } +} + +func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini3(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-3-flash-preview", + nil, + ) + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want 
map", generationConfig["thinkingConfig"]) + } + if got := thinkingConfig["thinkingLevel"]; got != "minimal" { + t.Fatalf("thinkingLevel = %#v, want minimal for default/off", got) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts { + t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"]) + } +} + +func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25Pro(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-pro", + nil, + ) + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"]) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts { + t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"]) + } + if _, hasBudget := thinkingConfig["thinkingBudget"]; hasBudget { + t.Fatalf("thinkingBudget should be omitted for Gemini 2.5 Pro default/off: %#v", thinkingConfig) + } +} + +func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini31Pro(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-3.1-pro", + nil, + ) + + generationConfig, ok := body["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig = %#v, want map", body["generationConfig"]) + } + thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any) + if !ok { + t.Fatalf("thinkingConfig = %#v, want map", 
generationConfig["thinkingConfig"]) + } + if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts { + t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"]) + } + if _, hasLevel := thinkingConfig["thinkingLevel"]; hasLevel { + t.Fatalf("thinkingLevel should be omitted for Gemini 3.1 Pro default/off: %#v", thinkingConfig) + } +} + +func TestGeminiProvider_BuildRequestBody_PreservesMultipleSystemMessages(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "system", Content: "Be concise."}, + {Role: "user", Content: "hello"}, + }, + nil, + "gemini-3-flash-preview", + nil, + ) + + systemInstruction, ok := body["systemInstruction"].(*geminiContent) + if !ok || systemInstruction == nil { + t.Fatalf("systemInstruction = %#v, want *geminiContent", body["systemInstruction"]) + } + if len(systemInstruction.Parts) != 2 { + t.Fatalf("systemInstruction.Parts len = %d, want 2", len(systemInstruction.Parts)) + } + if systemInstruction.Parts[0].Text != "You are helpful." || systemInstruction.Parts[1].Text != "Be concise." 
{ + t.Fatalf("systemInstruction.Parts = %#v, want ordered system prompts", systemInstruction.Parts) + } +} + +func TestGeminiProvider_BuildRequestBody_PreservesToolResponseMedia(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + body := provider.buildRequestBody( + []Message{ + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "load_image", + Arguments: map[string]any{"path": "demo.png"}, + }}, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: "tool result", + Media: []string{ + "data:image/png;base64,aW1hZ2VEYXRh", + "data:application/pdf;base64,UEZERGF0YQ==", + }, + }, + }, + nil, + "gemini-3-flash-preview", + nil, + ) + + contents, ok := body["contents"].([]geminiContent) + if !ok || len(contents) != 2 { + t.Fatalf("contents = %#v, want two content entries", body["contents"]) + } + parts := contents[1].Parts + if len(parts) != 1 || parts[0].FunctionResponse == nil { + t.Fatalf("tool response part = %#v, want functionResponse", parts) + } + response := parts[0].FunctionResponse + if response.Name != "load_image" { + t.Fatalf("functionResponse.Name = %q, want %q", response.Name, "load_image") + } + if response.Response["result"] != "tool result" { + t.Fatalf("functionResponse.Response = %#v, want result=tool result", response.Response) + } + if len(response.Parts) != 2 { + t.Fatalf("functionResponse.Parts len = %d, want 2", len(response.Parts)) + } +} + +func TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-token" { + t.Fatalf("Authorization = %q, want %q", got, "Bearer test-token") + } + if got := r.Header.Get("X-Goog-Api-Key"); got != "" { + t.Fatalf("X-Goog-Api-Key = %q, want empty", got) + } + w.Header().Set("Content-Type", "application/json") + _ = 
json.NewEncoder(w).Encode(map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{ + "parts": []any{map[string]any{"text": "ok"}}, + }, + "finishReason": "STOP", + }, + }, + }) + })) + defer server.Close() + + provider := NewGeminiProvider( + "", + server.URL, + "", + "", + 0, + nil, + map[string]string{"Authorization": "Bearer test-token"}, + ) + + resp, err := provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "ok" { + t.Fatalf("Content = %q, want %q", resp.Content, "ok") + } +} + +func TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Goog-Api-Key"); got != "" { + t.Fatalf("X-Goog-Api-Key = %q, want empty", got) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{"parts": []any{map[string]any{"text": "ok"}}}, + "finishReason": "STOP", + }, + }, + }) + })) + defer server.Close() + + provider := NewGeminiProvider("", server.URL, "", "", 0, nil, nil) + resp, err := provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "gemini-2.5-flash", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "ok" { + t.Fatalf("Content = %q, want %q", resp.Content, "ok") + } +} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index d25a0fce4..98a70cfd2 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "maps" "net/http" "net/url" "strings" @@ -181,9 +182,7 @@ func (p *Provider) buildRequestBody( // Merge extra body fields configured 
per-provider/model. // These are injected last so they take precedence over defaults. - for k, v := range p.extraBody { - requestBody[k] = v - } + maps.Copy(requestBody, p.extraBody) return requestBody } diff --git a/web/backend/api/session.go b/web/backend/api/session.go index ae580d9aa..9bb6055e2 100644 --- a/web/backend/api/session.go +++ b/web/backend/api/session.go @@ -281,6 +281,12 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen } case "assistant": + // Reasoning-only assistant messages are transient display artifacts and + // should not be restored from session history. + if assistantMessageTransientThought(msg) { + continue + } + toolSummaryMessages := visibleAssistantToolSummaryMessages(msg.ToolCalls, toolFeedbackMaxArgsLength) if len(toolSummaryMessages) > 0 { transcript = append(transcript, toolSummaryMessages...) @@ -309,6 +315,13 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen return transcript } +func assistantMessageTransientThought(msg providers.Message) bool { + return strings.TrimSpace(msg.Content) == "" && + strings.TrimSpace(msg.ReasoningContent) != "" && + len(msg.ToolCalls) == 0 && + len(msg.Media) == 0 +} + func assistantMessageInternalOnly(msg providers.Message) bool { return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText } diff --git a/web/backend/api/session_test.go b/web/backend/api/session_test.go index 5d7620362..599921bfe 100644 --- a/web/backend/api/session_test.go +++ b/web/backend/api/session_test.go @@ -218,6 +218,59 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) { } } +func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "detail-transient-thought" + for _, msg := 
range []providers.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", ReasoningContent: "internal chain of thought"}, + {Role: "assistant", Content: "final visible answer"}, + } { + if err := store.AddFullMessage(nil, sessionKey, msg); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-transient-thought", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + Messages []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"messages"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(resp.Messages) != 2 { + t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages)) + } + if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "hello" { + t.Fatalf("first message = %#v, want user/hello", resp.Messages[0]) + } + if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "final visible answer" { + t.Fatalf("second message = %#v, want assistant/final visible answer", resp.Messages[1]) + } +} + func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() diff --git a/web/frontend/src/components/chat/assistant-message.tsx b/web/frontend/src/components/chat/assistant-message.tsx index 9966226b2..8dcbe15a1 100644 --- a/web/frontend/src/components/chat/assistant-message.tsx +++ b/web/frontend/src/components/chat/assistant-message.tsx @@ -1,5 +1,6 @@ -import { IconCheck, IconCopy } from "@tabler/icons-react" +import { IconBrain, IconCheck, IconCopy } from "@tabler/icons-react" import { useState } from "react" +import { useTranslation } 
from "react-i18next" import ReactMarkdown from "react-markdown" import rehypeRaw from "rehype-raw" import rehypeSanitize from "rehype-sanitize" @@ -7,16 +8,20 @@ import remarkGfm from "remark-gfm" import { Button } from "@/components/ui/button" import { formatMessageTime } from "@/hooks/use-pico-chat" +import { cn } from "@/lib/utils" interface AssistantMessageProps { content: string + isThought?: boolean timestamp?: string | number } export function AssistantMessage({ content, + isThought = false, timestamp = "", }: AssistantMessageProps) { + const { t } = useTranslation() const [isCopied, setIsCopied] = useState(false) const formattedTimestamp = timestamp !== "" ? formatMessageTime(timestamp) : "" @@ -33,6 +38,12 @@ export function AssistantMessage({
PicoClaw + {isThought && ( + + + {t("chat.reasoningLabel")} + + )} {formattedTimestamp && ( <> @@ -42,8 +53,22 @@ export function AssistantMessage({
-
-
+
+
{isCopied ? ( diff --git a/web/frontend/src/components/chat/chat-page.tsx b/web/frontend/src/components/chat/chat-page.tsx index 38a0fc6b1..e8e07a801 100644 --- a/web/frontend/src/components/chat/chat-page.tsx +++ b/web/frontend/src/components/chat/chat-page.tsx @@ -247,6 +247,7 @@ export function ChatPage() { {msg.role === "assistant" ? ( ) : ( diff --git a/web/frontend/src/features/chat/history.ts b/web/frontend/src/features/chat/history.ts index 850b3319e..92beb06b7 100644 --- a/web/frontend/src/features/chat/history.ts +++ b/web/frontend/src/features/chat/history.ts @@ -24,6 +24,7 @@ export async function loadSessionMessages( id: `hist-${index}-${Date.now()}`, role: message.role, content: message.content, + kind: message.role === "assistant" ? "normal" : undefined, attachments: toChatAttachments(message.media), timestamp: fallbackTime, })) @@ -50,7 +51,7 @@ function messageSignature(message: ChatMessage): string { return `${message.role}\u0000${message.content}\u0000${normalizeMessageTimestamp( message.timestamp, - )}\u0000${attachmentSignature}` + )}\u0000${message.kind ?? ""}\u0000${attachmentSignature}` } function comparableTimestamp(timestamp: number | string): number { diff --git a/web/frontend/src/features/chat/protocol.ts b/web/frontend/src/features/chat/protocol.ts index 7429aef01..a7edfc21b 100644 --- a/web/frontend/src/features/chat/protocol.ts +++ b/web/frontend/src/features/chat/protocol.ts @@ -1,7 +1,10 @@ import { toast } from "sonner" import { normalizeUnixTimestamp } from "@/features/chat/state" -import { updateChatStore } from "@/store/chat" +import { + type AssistantMessageKind, + updateChatStore, +} from "@/store/chat" export interface PicoMessage { type: string @@ -11,6 +14,16 @@ export interface PicoMessage { payload?: Record } +function parseAssistantMessageKind( + payload: Record, +): AssistantMessageKind { + return payload.thought === true ? 
"thought" : "normal" +} + +function hasAssistantKindPayload(payload: Record): boolean { + return typeof payload.thought === "boolean" +} + export function handlePicoMessage( message: PicoMessage, expectedSessionId: string, @@ -25,6 +38,7 @@ export function handlePicoMessage( case "message.create": { const content = (payload.content as string) || "" const messageId = (payload.message_id as string) || `pico-${Date.now()}` + const kind = parseAssistantMessageKind(payload) const timestamp = message.timestamp !== undefined && Number.isFinite(Number(message.timestamp)) @@ -38,6 +52,7 @@ export function handlePicoMessage( id: messageId, role: "assistant", content, + kind, timestamp, }, ], @@ -49,13 +64,21 @@ export function handlePicoMessage( case "message.update": { const content = (payload.content as string) || "" const messageId = payload.message_id as string + const hasKind = hasAssistantKindPayload(payload) + const kind = parseAssistantMessageKind(payload) if (!messageId) { break } updateChatStore((prev) => ({ messages: prev.messages.map((msg) => - msg.id === messageId ? { ...msg, content } : msg, + msg.id === messageId + ? { + ...msg, + content, + ...(hasKind ? { kind } : {}), + } + : msg, ), })) break diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index b53abeb76..2434d4576 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -47,6 +47,7 @@ "step3": "Preparing response...", "step4": "Almost there..." }, + "reasoningLabel": "Reasoning", "history": "History", "noHistory": "No chat history yet", "historyLoadFailed": "Failed to load chat history", diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index e2e8eae04..c03d4181d 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -47,6 +47,7 @@ "step3": "准备回复...", "step4": "马上就好..." 
}, + "reasoningLabel": "思考", "history": "历史记录", "noHistory": "暂无对话历史", "historyLoadFailed": "加载历史记录失败", diff --git a/web/frontend/src/store/chat.ts b/web/frontend/src/store/chat.ts index 21eb5edff..2c6f70610 100644 --- a/web/frontend/src/store/chat.ts +++ b/web/frontend/src/store/chat.ts @@ -11,11 +11,14 @@ export interface ChatAttachment { filename?: string } +export type AssistantMessageKind = "normal" | "thought" + export interface ChatMessage { id: string role: "user" | "assistant" content: string timestamp: number | string + kind?: AssistantMessageKind attachments?: ChatAttachment[] }