Merge pull request #2475 from lc6464/fix/issue-2448-separate-thought-message

feat(gemini,pico): separate thought messages and add native Gemini provider
This commit is contained in:
daming大铭
2026-04-12 19:20:19 +08:00
committed by GitHub
22 changed files with 2004 additions and 30 deletions
+48 -7
View File
@@ -105,6 +105,8 @@ const (
toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
handledToolResponseSummary = "Requested output delivered via tool attachment."
sessionKeyAgentPrefix = "agent:"
metadataKeyMessageKind = "message_kind"
messageKindThought = "thought"
metadataKeyAccountID = "account_id"
metadataKeyGuildID = "guild_id"
metadataKeyTeamID = "team_id"
@@ -1622,6 +1624,41 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
return ""
}
func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, chatID string) {
if reasoningContent == "" || chatID == "" {
return
}
if ctx.Err() != nil {
return
}
pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
defer pubCancel()
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: "pico",
ChatID: chatID,
Content: reasoningContent,
Metadata: map[string]string{
metadataKeyMessageKind: messageKindThought,
},
}); err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) ||
errors.Is(err, bus.ErrBusClosed) {
logger.DebugCF("agent", "Pico reasoning publish skipped (timeout/cancel)", map[string]any{
"channel": "pico",
"error": err.Error(),
})
} else {
logger.WarnCF("agent", "Failed to publish pico reasoning (best-effort)", map[string]any{
"channel": "pico",
"error": err.Error(),
})
}
}
}
func (al *AgentLoop) handleReasoning(
ctx context.Context,
reasoningContent, channelName, channelID string,
@@ -2223,12 +2260,16 @@ turnLoop:
if reasoningContent == "" {
reasoningContent = response.ReasoningContent
}
go al.handleReasoning(
turnCtx,
reasoningContent,
ts.channel,
al.targetReasoningChannelID(ts.channel),
)
if ts.channel == "pico" {
go al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID)
} else {
go al.handleReasoning(
turnCtx,
reasoningContent,
ts.channel,
al.targetReasoningChannelID(ts.channel),
)
}
al.emitEvent(
EventKindLLMResponse,
ts.eventMeta("runTurn", "turn.llm.response"),
@@ -2277,7 +2318,7 @@ turnLoop:
if len(response.ToolCalls) == 0 || gracefulTerminal {
responseContent := response.Content
if responseContent == "" && response.ReasoningContent != "" {
if responseContent == "" && response.ReasoningContent != "" && ts.channel != "pico" {
responseContent = response.ReasoningContent
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
+57 -1
View File
@@ -1921,7 +1921,7 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
},
{
ModelName: "gemma-fallback",
Model: "gemini/gemma-3-27b-it",
Model: "openrouter/gemma-3-27b-it",
APIBase: fallbackServer.URL,
APIKeys: config.SimpleSecureStrings("fallback-key"),
Workspace: workspace,
@@ -2660,6 +2660,62 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
}
}
func TestProcessMessage_PicoPublishesReasoningAsThoughtMessage(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &reasoningContentProvider{
response: "final answer",
reasoningContent: "thinking trace",
}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "pico",
SenderID: "user1",
ChatID: "pico:test-session",
Content: "hello",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "final answer" {
t.Fatalf("processMessage() response = %q, want %q", response, "final answer")
}
var thoughtMsg *bus.OutboundMessage
deadline := time.After(3 * time.Second)
for thoughtMsg == nil {
select {
case outbound := <-msgBus.OutboundChan():
msg := outbound
if msg.Content == "thinking trace" {
thoughtMsg = &msg
}
case <-deadline:
t.Fatal("expected thought outbound message for pico")
}
}
if thoughtMsg.Channel != "pico" || thoughtMsg.ChatID != "pico:test-session" {
t.Fatalf("thought message route = %s/%s, want pico/pico:test-session", thoughtMsg.Channel, thoughtMsg.ChatID)
}
if thoughtMsg.Metadata[metadataKeyMessageKind] != messageKindThought {
t.Fatalf("thought metadata kind = %q, want %q", thoughtMsg.Metadata[metadataKeyMessageKind], messageKindThought)
}
}
func TestProcessHeartbeat_DoesNotPublishToolFeedback(t *testing.T) {
tmpDir := t.TempDir()
heartbeatFile := filepath.Join(tmpDir, "heartbeat-task.txt")
+6 -2
View File
@@ -242,7 +242,11 @@ func (c *PicoClientChannel) handleInbound(pc *picoConn, msg PicoMessage) {
}
func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
content, _ := msg.Payload["content"].(string)
if isThoughtPayload(msg.Payload) {
return
}
content, _ := msg.Payload[PayloadKeyContent].(string)
if strings.TrimSpace(content) == "" {
return
}
@@ -285,7 +289,7 @@ func (c *PicoClientChannel) Send(ctx context.Context, msg bus.OutboundMessage) (
}
outMsg := newMessage(TypeMessageSend, map[string]any{
"content": msg.Content,
PayloadKeyContent: msg.Content,
})
outMsg.SessionID = strings.TrimPrefix(msg.ChatID, "pico_client:")
return nil, pc.writeJSON(outMsg)
+64
View File
@@ -316,3 +316,67 @@ func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
t.Fatal("timed out waiting for inbound media message")
}
}
func TestIsThoughtPayload(t *testing.T) {
tests := []struct {
name string
payload map[string]any
want bool
}{
{
name: "explicit thought bool",
payload: map[string]any{PayloadKeyThought: true},
want: true,
},
{
name: "thought false",
payload: map[string]any{PayloadKeyThought: false},
want: false,
},
{
name: "thought string ignored",
payload: map[string]any{PayloadKeyThought: "true"},
want: false,
},
{
name: "default normal",
payload: map[string]any{PayloadKeyContent: "hello"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isThoughtPayload(tt.payload); got != tt.want {
t.Fatalf("isThoughtPayload() = %v, want %v", got, tt.want)
}
})
}
}
func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
URL: "ws://localhost:8080/ws",
}, mb)
if err != nil {
t.Fatalf("NewPicoClientChannel() error = %v", err)
}
ch.ctx = context.Background()
pc := &picoConn{sessionID: "sess-thought"}
ch.handleServerMessage(pc, PicoMessage{
Type: TypeMessageCreate,
Payload: map[string]any{
PayloadKeyContent: "internal reasoning",
PayloadKeyThought: true,
},
})
select {
case msg := <-mb.InboundChan():
t.Fatalf("expected no inbound publish for thought payload, got %+v", msg)
case <-time.After(150 * time.Millisecond):
}
}
+13 -3
View File
@@ -39,6 +39,13 @@ var allowedInlineImageMIMETypes = map[string]struct{}{
"image/bmp": {},
}
func outboundMessageIsThought(metadata map[string]string) bool {
if len(metadata) == 0 {
return false
}
return strings.EqualFold(strings.TrimSpace(metadata["message_kind"]), MessageKindThought)
}
// writeJSON sends a JSON message to the connection with write locking.
func (pc *picoConn) writeJSON(v any) error {
if pc.closed.Load() {
@@ -247,9 +254,11 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
isThought := outboundMessageIsThought(msg.Metadata)
outMsg := newMessage(TypeMessageCreate, map[string]any{
"content": msg.Content,
PayloadKeyContent: msg.Content,
PayloadKeyThought: isThought,
})
return nil, c.broadcastToSession(msg.ChatID, outMsg)
@@ -288,8 +297,9 @@ func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (strin
msgID := uuid.New().String()
outMsg := newMessage(TypeMessageCreate, map[string]any{
"content": text,
"message_id": msgID,
PayloadKeyContent: text,
PayloadKeyThought: false,
"message_id": msgID,
})
if err := c.broadcastToSession(chatID, outMsg); err != nil {
+10
View File
@@ -19,6 +19,11 @@ const (
TypePong = "pong"
PicoTokenPrefix = "pico-"
PayloadKeyContent = "content"
PayloadKeyThought = "thought"
MessageKindThought = "thought"
)
// PicoMessage is the wire format for all Pico Protocol messages.
@@ -39,6 +44,11 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
}
}
func isThoughtPayload(payload map[string]any) bool {
thought, _ := payload[PayloadKeyThought].(bool)
return thought
}
func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage {
payload := map[string]any{
"code": code,
+12 -5
View File
@@ -389,6 +389,7 @@ type antigravityJSONResponse struct {
Content struct {
Parts []struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"`
ThoughtSignatureSnake string `json:"thought_signature,omitempty"`
FunctionCall *antigravityFunctionCall `json:"functionCall,omitempty"`
@@ -406,6 +407,7 @@ type antigravityJSONResponse struct {
func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error) {
var contentParts []string
var reasoningParts []string
var toolCalls []ToolCall
var usage *UsageInfo
var finishReason string
@@ -433,7 +435,11 @@ func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error
for _, candidate := range resp.Candidates {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
contentParts = append(contentParts, part.Text)
if part.Thought {
reasoningParts = append(reasoningParts, part.Text)
} else {
contentParts = append(contentParts, part.Text)
}
}
if part.FunctionCall != nil {
argumentsJSON, _ := json.Marshal(part.FunctionCall.Args)
@@ -475,10 +481,11 @@ func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error
}
return &LLMResponse{
Content: strings.Join(contentParts, ""),
ToolCalls: toolCalls,
FinishReason: mappedFinish,
Usage: usage,
Content: strings.Join(contentParts, ""),
ReasoningContent: strings.Join(reasoningParts, ""),
ToolCalls: toolCalls,
FinishReason: mappedFinish,
Usage: usage,
}, nil
}
@@ -54,3 +54,27 @@ func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) {
t.Fatalf("expected inferred tool name search_docs, got %q", got)
}
}
func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
p := &AntigravityProvider{}
body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" +
"data: [DONE]\n"
resp, err := p.parseSSEResponse(body)
if err != nil {
t.Fatalf("parseSSEResponse() error = %v", err)
}
if resp.Content != "visible answer" {
t.Fatalf("Content = %q, want %q", resp.Content, "visible answer")
}
if resp.ReasoningContent != "hidden reasoning" {
t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden reasoning")
}
if resp.FinishReason != "stop" {
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil || resp.Usage.TotalTokens != 216 {
t.Fatalf("Usage.TotalTokens = %v, want %d", resp.Usage, 216)
}
}
+20 -2
View File
@@ -114,7 +114,7 @@ func ResolveAPIBase(cfg *config.ModelConfig) string {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq),
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
// See the switch on protocol in this function for the authoritative list.
// Returns the provider, the model ID (without protocol prefix), and any error.
@@ -218,7 +218,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
}
return provider, modelID, nil
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "gemini", "nvidia", "venice",
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
"qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita",
@@ -242,6 +242,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
), modelID, nil
case "gemini":
if cfg.APIKey() == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for gemini protocol (model: %s)", cfg.Model)
}
apiBase := cfg.APIBase
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewGeminiProvider(
cfg.APIKey(),
apiBase,
cfg.Proxy,
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
cfg.CustomHeaders,
), modelID, nil
case "minimax":
// Minimax requires reasoning_split: true in the request body
if cfg.APIKey() == "" && cfg.APIBase == "" {
+56
View File
@@ -434,6 +434,62 @@ func TestCreateProviderFromConfig_Antigravity(t *testing.T) {
}
}
func TestCreateProviderFromConfig_Gemini(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-gemini",
Model: "gemini/gemini-2.5-flash",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "gemini-2.5-flash" {
t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash")
}
if _, ok := provider.(*GeminiProvider); !ok {
t.Fatalf("expected *GeminiProvider, got %T", provider)
}
}
func TestCreateProviderFromConfig_GeminiMissingAPIKey(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-gemini-no-key",
Model: "gemini/gemini-2.5-flash",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for missing gemini API key")
}
}
func TestCreateProviderFromConfig_GeminiCustomAPIBaseWithoutKey(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-gemini-custom-base",
Model: "gemini/gemini-2.5-flash",
APIBase: "https://proxy.example.com/v1beta",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "gemini-2.5-flash" {
t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash")
}
if _, ok := provider.(*GeminiProvider); !ok {
t.Fatalf("expected *GeminiProvider, got %T", provider)
}
}
func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-claude-cli",
+796
View File
@@ -0,0 +1,796 @@
package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
const (
geminiDefaultAPIBase = "https://generativelanguage.googleapis.com/v1beta"
geminiDefaultModel = "gemini-2.0-flash"
)
type GeminiProvider struct {
apiKey string
apiBase string
httpClient *http.Client
extraBody map[string]any
customHeaders map[string]string
userAgent string
}
func NewGeminiProvider(
apiKey string,
apiBase string,
proxy string,
userAgent string,
requestTimeoutSeconds int,
extraBody map[string]any,
customHeaders map[string]string,
) *GeminiProvider {
if strings.TrimSpace(apiBase) == "" {
apiBase = geminiDefaultAPIBase
}
client := common.NewHTTPClient(proxy)
if requestTimeoutSeconds > 0 {
client.Timeout = time.Duration(requestTimeoutSeconds) * time.Second
}
return &GeminiProvider{
apiKey: strings.TrimSpace(apiKey),
apiBase: strings.TrimRight(strings.TrimSpace(apiBase), "/"),
httpClient: client,
extraBody: cloneAnyMap(extraBody),
customHeaders: cloneStringMap(customHeaders),
userAgent: strings.TrimSpace(userAgent),
}
}
func (p *GeminiProvider) GetDefaultModel() string {
return geminiDefaultModel
}
func (p *GeminiProvider) SupportsThinking() bool {
return true
}
func (p *GeminiProvider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
model = normalizeGeminiModel(model)
requestBody := p.buildRequestBody(messages, tools, model, options)
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
url := fmt.Sprintf("%s/models/%s:generateContent", p.apiBase, model)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
p.applyHeaders(req)
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
var apiResp geminiGenerateContentResponse
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return parseGeminiResponse(&apiResp), nil
}
func (p *GeminiProvider) ChatStream(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
onChunk func(accumulated string),
) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
model = normalizeGeminiModel(model)
requestBody := p.buildRequestBody(messages, tools, model, options)
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?alt=sse", p.apiBase, model)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
p.applyHeaders(req)
req.Header.Set("Accept", "text/event-stream")
// Streaming should not use a whole-request timeout; context cancellation is the guard.
streamClient := &http.Client{Transport: p.httpClient.Transport}
resp, err := streamClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
return parseGeminiStreamResponse(ctx, resp.Body, onChunk)
}
func (p *GeminiProvider) applyHeaders(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("X-Goog-Api-Key", p.apiKey)
}
if p.userAgent != "" {
req.Header.Set("User-Agent", p.userAgent)
}
for k, v := range p.customHeaders {
if strings.TrimSpace(k) == "" {
continue
}
req.Header.Set(k, v)
}
}
func (p *GeminiProvider) buildRequestBody(
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) map[string]any {
contents := make([]geminiContent, 0, len(messages))
toolCallNames := make(map[string]string)
systemPrompts := make([]string, 0, 1)
for _, msg := range messages {
switch msg.Role {
case "system":
if strings.TrimSpace(msg.Content) != "" {
systemPrompts = append(systemPrompts, msg.Content)
}
case "user":
if msg.ToolCallID != "" {
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
}},
})
continue
}
parts := make([]geminiPart, 0, 1+len(msg.Media))
if strings.TrimSpace(msg.Content) != "" {
parts = append(parts, geminiPart{Text: msg.Content})
}
parts = append(parts, buildInlineMediaParts(msg.Media)...)
if len(parts) > 0 {
contents = append(contents, geminiContent{Role: "user", Parts: parts})
}
case "assistant":
content := geminiContent{Role: "model"}
if strings.TrimSpace(msg.Content) != "" {
content.Parts = append(content.Parts, geminiPart{Text: msg.Content})
}
for _, tc := range msg.ToolCalls {
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
if toolName == "" {
continue
}
if tc.ID != "" {
toolCallNames[tc.ID] = toolName
}
part := geminiPart{
FunctionCall: &geminiFunctionCall{
Name: toolName,
Args: toolArgs,
ID: tc.ID,
},
}
if thoughtSignature != "" {
part.ThoughtSignature = thoughtSignature
}
content.Parts = append(content.Parts, part)
}
if len(content.Parts) > 0 {
contents = append(contents, content)
}
case "tool":
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
}},
})
}
}
body := map[string]any{
"contents": contents,
}
if len(systemPrompts) > 0 {
systemParts := make([]geminiPart, 0, len(systemPrompts))
for _, prompt := range systemPrompts {
systemParts = append(systemParts, geminiPart{Text: prompt})
}
body["systemInstruction"] = &geminiContent{Parts: systemParts}
}
if len(tools) > 0 {
funcDecls := make([]geminiFunctionDeclaration, 0, len(tools))
for _, t := range tools {
if t.Type != "function" {
continue
}
funcDecls = append(funcDecls, geminiFunctionDeclaration{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: sanitizeSchemaForGemini(t.Function.Parameters),
})
}
if len(funcDecls) > 0 {
body["tools"] = []geminiTool{{FunctionDeclarations: funcDecls}}
}
}
generationConfig := make(map[string]any)
if val, ok := options["max_tokens"]; ok {
if maxTokens, ok := val.(int); ok && maxTokens > 0 {
generationConfig["maxOutputTokens"] = maxTokens
} else if maxTokens, ok := val.(float64); ok && maxTokens > 0 {
generationConfig["maxOutputTokens"] = int(maxTokens)
}
}
if temp, ok := options["temperature"].(float64); ok {
generationConfig["temperature"] = temp
}
if thinkingConfig := buildGeminiThinkingConfig(model, options); len(thinkingConfig) > 0 {
generationConfig["thinkingConfig"] = thinkingConfig
}
if len(generationConfig) > 0 {
body["generationConfig"] = generationConfig
}
for k, v := range p.extraBody {
body[k] = v
}
return body
}
func normalizeGeminiModel(model string) string {
model = strings.TrimSpace(model)
model = strings.TrimPrefix(model, "models/")
if strings.Contains(model, "/") {
_, modelID := ExtractProtocol(model)
if modelID != "" {
return modelID
}
}
if model == "" {
return geminiDefaultModel
}
return model
}
func mapGeminiThinkingLevel(level string) string {
switch strings.ToLower(strings.TrimSpace(level)) {
case "minimal", "off":
return "minimal"
case "low":
return "low"
case "medium":
return "medium"
case "high", "xhigh", "adaptive":
return "high"
default:
return ""
}
}
func buildGeminiThinkingConfig(model string, options map[string]any) map[string]any {
if !geminiModelSupportsThinkingConfig(model) {
return nil
}
config := map[string]any{}
rawLevel, _ := options["thinking_level"].(string)
rawLevel = strings.ToLower(strings.TrimSpace(rawLevel))
if rawLevel == "" {
// Align with agent-level default: unset means ThinkingOff.
rawLevel = "off"
}
includeThoughts := rawLevel != "off" && rawLevel != "minimal"
config["includeThoughts"] = includeThoughts
if isGemini25Model(model) {
if isGemini25ProModel(model) && (rawLevel == "off" || rawLevel == "minimal") {
// Gemini 2.5 Pro cannot disable thinking; keep model-default thinking.
return config
}
if budget, ok := mapGeminiThinkingBudget(rawLevel); ok {
config["thinkingBudget"] = budget
}
return config
}
if isGemini3ProModel(model) && (rawLevel == "off" || rawLevel == "minimal") {
// Gemini 3.x Pro does not support minimal thinking level.
return config
}
if thinkingLevel := mapGeminiThinkingLevel(rawLevel); thinkingLevel != "" {
config["thinkingLevel"] = thinkingLevel
}
return config
}
func geminiModelSupportsThinkingConfig(model string) bool {
lowerModel := strings.ToLower(strings.TrimSpace(model))
return strings.Contains(lowerModel, "gemini-3") || isGemini25Model(lowerModel)
}
func isGemini25Model(model string) bool {
lowerModel := strings.ToLower(strings.TrimSpace(model))
return strings.Contains(lowerModel, "gemini-2.5") || strings.Contains(lowerModel, "gemini-25")
}
func isGemini25ProModel(model string) bool {
lowerModel := strings.ToLower(strings.TrimSpace(model))
return isGemini25Model(lowerModel) && strings.Contains(lowerModel, "pro")
}
func isGemini3ProModel(model string) bool {
lowerModel := strings.ToLower(strings.TrimSpace(model))
return strings.Contains(lowerModel, "gemini-3") && strings.Contains(lowerModel, "pro")
}
func mapGeminiThinkingBudget(level string) (int, bool) {
level = strings.ToLower(strings.TrimSpace(level))
if level == "" {
return 0, false
}
switch level {
case "adaptive":
return -1, true
case "minimal":
return 0, true
case "off":
return 0, true
case "low":
return 1024, true
case "medium":
return 4096, true
case "high":
return 8192, true
case "xhigh":
return 16384, true
default:
return 0, false
}
}
func parseGeminiResponse(resp *geminiGenerateContentResponse) *LLMResponse {
contentParts := make([]string, 0)
reasoningParts := make([]string, 0)
toolCalls := make([]ToolCall, 0)
finishReason := ""
for _, candidate := range resp.Candidates {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
if part.Thought {
reasoningParts = append(reasoningParts, part.Text)
} else {
contentParts = append(contentParts, part.Text)
}
}
if part.FunctionCall != nil {
toolCalls = append(toolCalls, buildGeminiToolCall(part))
}
}
if candidate.FinishReason != "" {
finishReason = candidate.FinishReason
}
}
var usage *UsageInfo
if resp.UsageMetadata.TotalTokenCount > 0 {
usage = &UsageInfo{
PromptTokens: resp.UsageMetadata.PromptTokenCount,
CompletionTokens: resp.UsageMetadata.CandidatesTokenCount,
TotalTokens: resp.UsageMetadata.TotalTokenCount,
}
}
return &LLMResponse{
Content: strings.Join(contentParts, ""),
ReasoningContent: strings.Join(reasoningParts, ""),
ToolCalls: toolCalls,
FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)),
Usage: usage,
}
}
func parseGeminiStreamResponse(
ctx context.Context,
reader io.Reader,
onChunk func(accumulated string),
) (*LLMResponse, error) {
var contentBuilder strings.Builder
var reasoningBuilder strings.Builder
var finishReason string
var usage *UsageInfo
toolCallsByID := make(map[string]ToolCall)
toolCallOrder := make([]string, 0)
fallbackIndex := 0
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024)
for scanner.Scan() {
if err := ctx.Err(); err != nil {
return nil, err
}
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if data == "" {
continue
}
if data == "[DONE]" {
break
}
var chunk geminiGenerateContentResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
return nil, fmt.Errorf("invalid gemini stream chunk: %w", err)
}
for _, candidate := range chunk.Candidates {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
if part.Thought {
reasoningBuilder.WriteString(part.Text)
} else {
contentBuilder.WriteString(part.Text)
if onChunk != nil {
onChunk(contentBuilder.String())
}
}
}
if part.FunctionCall != nil {
tc := buildGeminiToolCall(part)
if strings.TrimSpace(tc.Name) == "" {
continue
}
key := strings.TrimSpace(part.FunctionCall.ID)
if key == "" {
if len(toolCallOrder) > 0 {
lastKey := toolCallOrder[len(toolCallOrder)-1]
if lastTC, exists := toolCallsByID[lastKey]; exists && lastTC.Name == tc.Name {
key = lastKey
}
}
if key == "" {
fallbackIndex++
key = fmt.Sprintf("%s#%d", tc.Name, fallbackIndex)
}
}
tc.ID = key
if _, exists := toolCallsByID[key]; !exists {
toolCallOrder = append(toolCallOrder, key)
}
toolCallsByID[key] = tc
}
}
if candidate.FinishReason != "" {
finishReason = candidate.FinishReason
}
}
if chunk.UsageMetadata.TotalTokenCount > 0 {
usage = &UsageInfo{
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
}
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("streaming read error: %w", err)
}
toolCalls := make([]ToolCall, 0, len(toolCallOrder))
for _, key := range toolCallOrder {
toolCalls = append(toolCalls, toolCallsByID[key])
}
return &LLMResponse{
Content: contentBuilder.String(),
ReasoningContent: reasoningBuilder.String(),
ToolCalls: toolCalls,
FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)),
Usage: usage,
}, nil
}
func normalizeGeminiFinishReason(reason string, toolCalls int) string {
if toolCalls > 0 {
return "tool_calls"
}
switch strings.ToUpper(strings.TrimSpace(reason)) {
case "MAX_TOKENS":
return "length"
case "", "STOP":
return "stop"
default:
return strings.ToLower(strings.TrimSpace(reason))
}
}
func buildGeminiToolCall(part geminiPart) ToolCall {
if part.FunctionCall == nil {
return ToolCall{}
}
args := part.FunctionCall.Args
if args == nil {
args = make(map[string]any)
}
argsJSON, _ := json.Marshal(args)
thoughtSignature := extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake)
toolCall := ToolCall{
ID: part.FunctionCall.ID,
Name: part.FunctionCall.Name,
Arguments: args,
ThoughtSignature: thoughtSignature,
Function: &FunctionCall{
Name: part.FunctionCall.Name,
Arguments: string(argsJSON),
ThoughtSignature: thoughtSignature,
},
}
if thoughtSignature != "" {
toolCall.ExtraContent = &ExtraContent{
Google: &GoogleExtra{ThoughtSignature: thoughtSignature},
}
}
if strings.TrimSpace(toolCall.ID) == "" {
toolCall.ID = fmt.Sprintf("call_%s_%d", toolCall.Name, time.Now().UnixNano())
}
return toolCall
}
func buildInlineMediaParts(media []string) []geminiPart {
parts := make([]geminiPart, 0, len(media))
for _, mediaURL := range media {
mimeType, data, ok := parseBase64DataURL(mediaURL)
if !ok {
continue
}
parts = append(parts, geminiPart{
InlineData: &geminiInlineData{
MIMEType: mimeType,
Data: data,
},
})
}
return parts
}
func buildGeminiFunctionResponse(
toolName string,
toolCallID string,
result string,
media []string,
) *geminiFunctionResponse {
response := &geminiFunctionResponse{
ID: toolCallID,
Name: toolName,
Response: map[string]any{
"result": result,
},
}
if parts := buildFunctionResponseMediaParts(media); len(parts) > 0 {
response.Parts = parts
}
return response
}
func buildFunctionResponseMediaParts(media []string) []geminiFunctionResponsePart {
parts := make([]geminiFunctionResponsePart, 0, len(media))
for i, mediaURL := range media {
mimeType, data, ok := parseBase64DataURL(mediaURL)
if !ok {
continue
}
parts = append(parts, geminiFunctionResponsePart{
InlineData: &geminiInlineData{
MIMEType: mimeType,
Data: data,
DisplayName: defaultFunctionResponseDisplayName(mimeType, i+1),
},
})
}
return parts
}
func defaultFunctionResponseDisplayName(mimeType string, index int) string {
suffix := "bin"
switch strings.ToLower(strings.TrimSpace(mimeType)) {
case "image/png":
suffix = "png"
case "image/jpeg":
suffix = "jpg"
case "image/webp":
suffix = "webp"
case "application/pdf":
suffix = "pdf"
case "text/plain":
suffix = "txt"
}
return fmt.Sprintf("attachment-%d.%s", index, suffix)
}
func parseBase64DataURL(mediaURL string) (mimeType string, data string, ok bool) {
if !strings.HasPrefix(mediaURL, "data:") {
return "", "", false
}
payload := strings.TrimPrefix(mediaURL, "data:")
header, data, found := strings.Cut(payload, ",")
if !found {
return "", "", false
}
mimeType, params, _ := strings.Cut(header, ";")
mimeType = strings.TrimSpace(mimeType)
data = strings.TrimSpace(data)
if mimeType == "" || data == "" {
return "", "", false
}
if !strings.Contains(strings.ToLower(params), "base64") {
return "", "", false
}
return mimeType, data, true
}
func cloneAnyMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func cloneStringMap(in map[string]string) map[string]string {
if len(in) == 0 {
return nil
}
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
type geminiGenerateContentResponse struct {
Candidates []struct {
Content struct {
Role string `json:"role"`
Parts []geminiPart `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
type geminiContent struct {
Role string `json:"role,omitempty"`
Parts []geminiPart `json:"parts"`
}
type geminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"`
ThoughtSignatureSnake string `json:"thought_signature,omitempty"`
InlineData *geminiInlineData `json:"inlineData,omitempty"`
FunctionCall *geminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *geminiFunctionResponse `json:"functionResponse,omitempty"`
}
type geminiInlineData struct {
MIMEType string `json:"mimeType"`
Data string `json:"data"`
DisplayName string `json:"displayName,omitempty"`
}
type geminiFunctionCall struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Args map[string]any `json:"args,omitempty"`
}
type geminiFunctionResponse struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
Response map[string]any `json:"response"`
Parts []geminiFunctionResponsePart `json:"parts,omitempty"`
}
type geminiFunctionResponsePart struct {
InlineData *geminiInlineData `json:"inlineData,omitempty"`
}
type geminiTool struct {
FunctionDeclarations []geminiFunctionDeclaration `json:"functionDeclarations"`
}
type geminiFunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters any `json:"parameters,omitempty"`
}
+763
View File
@@ -0,0 +1,763 @@
package providers
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
var capturedBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("method = %s, want POST", r.Method)
}
if !strings.Contains(r.URL.Path, ":generateContent") {
t.Fatalf("path = %s, expected generateContent endpoint", r.URL.Path)
}
if got := r.Header.Get("X-Goog-Api-Key"); got != "test-key" {
t.Fatalf("X-Goog-Api-Key = %q, want %q", got, "test-key")
}
if err := json.NewDecoder(r.Body).Decode(&capturedBody); err != nil {
t.Fatalf("decode request body: %v", err)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"candidates": []any{
map[string]any{
"content": map[string]any{
"role": "model",
"parts": []any{
map[string]any{"text": "hidden", "thought": true},
map[string]any{"text": "visible"},
map[string]any{
"functionCall": map[string]any{
"id": "call_1",
"name": "search",
"args": map[string]any{"q": "hi"},
},
"thoughtSignature": "sig-1",
},
},
},
"finishReason": "STOP",
},
},
"usageMetadata": map[string]any{
"promptTokenCount": 2,
"candidatesTokenCount": 3,
"totalTokenCount": 5,
},
})
}))
defer server.Close()
provider := NewGeminiProvider("test-key", server.URL, "", "picoclaw-test", 0, nil, nil)
resp, err := provider.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-3-flash-preview",
map[string]any{"thinking_level": "high"},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if resp.Content != "visible" {
t.Fatalf("Content = %q, want %q", resp.Content, "visible")
}
if resp.ReasoningContent != "hidden" {
t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden")
}
if resp.FinishReason != "tool_calls" {
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if resp.Usage == nil || resp.Usage.TotalTokens != 5 {
t.Fatalf("Usage = %#v, expected total tokens = 5", resp.Usage)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
}
if resp.ToolCalls[0].ID != "call_1" {
t.Fatalf("ToolCall ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
}
if resp.ToolCalls[0].Name != "search" {
t.Fatalf("ToolCall Name = %q, want %q", resp.ToolCalls[0].Name, "search")
}
if resp.ToolCalls[0].ThoughtSignature != "sig-1" {
t.Fatalf("ToolCall ThoughtSignature = %q, want %q", resp.ToolCalls[0].ThoughtSignature, "sig-1")
}
if resp.ToolCalls[0].Function == nil || !strings.Contains(resp.ToolCalls[0].Function.Arguments, `"q":"hi"`) {
t.Fatalf("ToolCall Function arguments = %#v, want q=hi", resp.ToolCalls[0].Function)
}
generationConfig, ok := capturedBody["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("request missing generationConfig: %#v", capturedBody)
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("request missing thinkingConfig: %#v", generationConfig)
}
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts {
t.Fatalf("thinkingConfig.includeThoughts = %#v, want true", thinkingConfig["includeThoughts"])
}
if got := thinkingConfig["thinkingLevel"]; got != "high" {
t.Fatalf("thinkingConfig.thinkingLevel = %#v, want %q", got, "high")
}
}
func TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, ":streamGenerateContent") {
t.Fatalf("path = %s, expected streamGenerateContent endpoint", r.URL.Path)
}
if got := r.URL.Query().Get("alt"); got != "sse" {
t.Fatalf("alt query = %q, want %q", got, "sse")
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("response writer is not flushable")
}
chunks := []map[string]any{
{
"candidates": []any{map[string]any{
"content": map[string]any{
"parts": []any{
map[string]any{"text": "think ", "thought": true},
map[string]any{"text": "Hello "},
},
},
}},
},
{
"candidates": []any{map[string]any{
"content": map[string]any{
"parts": []any{
map[string]any{"text": "World"},
map[string]any{
"functionCall": map[string]any{
"id": "call_stream",
"name": "search",
"args": map[string]any{"q": "stream"},
},
},
},
},
"finishReason": "STOP",
}},
"usageMetadata": map[string]any{
"promptTokenCount": 1,
"candidatesTokenCount": 2,
"totalTokenCount": 3,
},
},
}
for _, chunk := range chunks {
raw, err := json.Marshal(chunk)
if err != nil {
t.Fatalf("marshal chunk: %v", err)
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil {
t.Fatalf("write chunk: %v", err)
}
flusher.Flush()
}
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
updates := make([]string, 0)
resp, err := provider.ChatStream(
t.Context(),
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
nil,
func(accumulated string) {
updates = append(updates, accumulated)
},
)
if err != nil {
t.Fatalf("ChatStream() error = %v", err)
}
if resp.Content != "Hello World" {
t.Fatalf("Content = %q, want %q", resp.Content, "Hello World")
}
if resp.ReasoningContent != "think " {
t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "think ")
}
if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].ID != "call_stream" {
t.Fatalf("ToolCalls = %#v, want single call_stream", resp.ToolCalls)
}
if resp.FinishReason != "tool_calls" {
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if resp.Usage == nil || resp.Usage.TotalTokens != 3 {
t.Fatalf("Usage = %#v, expected total tokens = 3", resp.Usage)
}
if len(updates) < 2 || updates[len(updates)-1] != "Hello World" {
t.Fatalf("stream updates = %#v, expected final accumulated text", updates)
}
}
func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("response writer is not flushable")
}
_, _ = fmt.Fprint(w, "data: \n\n")
flusher.Flush()
chunk := map[string]any{
"candidates": []any{map[string]any{
"content": map[string]any{
"parts": []any{map[string]any{"text": "ok"}},
},
"finishReason": "STOP",
}},
}
raw, err := json.Marshal(chunk)
if err != nil {
t.Fatalf("marshal chunk: %v", err)
}
_, _ = fmt.Fprintf(w, "data: %s\n\n", raw)
flusher.Flush()
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
resp, err := provider.ChatStream(
t.Context(),
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
nil,
nil,
)
if err != nil {
t.Fatalf("ChatStream() error = %v", err)
}
if resp.Content != "ok" {
t.Fatalf("Content = %q, want %q", resp.Content, "ok")
}
}
func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("response writer is not flushable")
}
_, _ = fmt.Fprint(w, "data: {invalid-json}\n\n")
flusher.Flush()
}))
defer server.Close()
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
_, err := provider.ChatStream(
t.Context(),
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
nil,
nil,
)
if err == nil {
t.Fatal("ChatStream() expected error for invalid SSE data frame")
}
if !strings.Contains(err.Error(), "invalid gemini stream chunk") {
t.Fatalf("error = %v, want contains %q", err, "invalid gemini stream chunk")
}
}
func TestGeminiProvider_BuildRequestBody_UsesCamelCaseThoughtSignatureOnly(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "search",
Arguments: map[string]any{"q": "hello"},
Function: &FunctionCall{
Name: "search",
Arguments: `{"q":"hello"}`,
ThoughtSignature: "sig-1",
},
}},
}},
nil,
"gemini-2.5-flash",
nil,
)
raw, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshal request body: %v", err)
}
jsonBody := string(raw)
if !strings.Contains(jsonBody, `"thoughtSignature":"sig-1"`) {
t.Fatalf("request body = %s, expected camelCase thoughtSignature", jsonBody)
}
if strings.Contains(jsonBody, `"thought_signature"`) {
t.Fatalf("request body = %s, unexpected snake_case thought_signature", jsonBody)
}
}
func TestGeminiProvider_ChatStreamCoalescesToolCallWithoutWireID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("response writer is not flushable")
}
chunks := []map[string]any{
{
"candidates": []any{map[string]any{
"content": map[string]any{
"parts": []any{
map[string]any{
"functionCall": map[string]any{
"name": "search",
"args": map[string]any{"q": "first"},
},
},
},
},
}},
},
{
"candidates": []any{map[string]any{
"content": map[string]any{
"parts": []any{
map[string]any{
"functionCall": map[string]any{
"name": "search",
"args": map[string]any{"q": "second"},
},
},
},
},
"finishReason": "STOP",
}},
},
}
for _, chunk := range chunks {
raw, err := json.Marshal(chunk)
if err != nil {
t.Fatalf("marshal chunk: %v", err)
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil {
t.Fatalf("write chunk: %v", err)
}
flusher.Flush()
}
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
resp, err := provider.ChatStream(
t.Context(),
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
nil,
nil,
)
if err != nil {
t.Fatalf("ChatStream() error = %v", err)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
}
tc := resp.ToolCalls[0]
if tc.ID != "search#1" {
t.Fatalf("ToolCall ID = %q, want %q", tc.ID, "search#1")
}
if tc.Name != "search" {
t.Fatalf("ToolCall Name = %q, want %q", tc.Name, "search")
}
if argQ, ok := tc.Arguments["q"].(string); !ok || argQ != "second" {
t.Fatalf("ToolCall Arguments = %#v, want q=second", tc.Arguments)
}
if resp.FinishReason != "tool_calls" {
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
}
func TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{
Role: "user",
Content: "analyze attachments",
Media: []string{
"data:application/pdf;base64,UEZERGF0YQ==",
"data:image/png;base64,aW1hZ2VEYXRh",
},
}},
nil,
"gemini-3-flash-preview",
map[string]any{
"thinking_level": "low",
"max_tokens": 128,
"temperature": 0.2,
},
)
contents, ok := body["contents"].([]geminiContent)
if !ok || len(contents) != 1 {
t.Fatalf("contents = %#v, want one gemini content", body["contents"])
}
parts := contents[0].Parts
mimeSet := map[string]bool{}
for _, part := range parts {
if part.InlineData != nil {
mimeSet[part.InlineData.MIMEType] = true
}
}
if !mimeSet["application/pdf"] {
t.Fatalf("inline media missing application/pdf: %#v", parts)
}
if !mimeSet["image/png"] {
t.Fatalf("inline media missing image/png: %#v", parts)
}
generationConfig, ok := body["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
}
if got := generationConfig["maxOutputTokens"]; got != 128 {
t.Fatalf("maxOutputTokens = %#v, want 128", got)
}
if got := generationConfig["temperature"]; got != 0.2 {
t.Fatalf("temperature = %#v, want 0.2", got)
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
}
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts {
t.Fatalf("includeThoughts = %#v, want true", thinkingConfig["includeThoughts"])
}
if got := thinkingConfig["thinkingLevel"]; got != "low" {
t.Fatalf("thinkingLevel = %#v, want %q", got, "low")
}
}
func TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
map[string]any{"thinking_level": "medium"},
)
generationConfig, ok := body["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
}
if got := thinkingConfig["thinkingBudget"]; got != 4096 {
t.Fatalf("thinkingBudget = %#v, want 4096", got)
}
if _, hasLevel := thinkingConfig["thinkingLevel"]; hasLevel {
t.Fatalf("thinkingLevel should not be set for Gemini 2.5: %#v", thinkingConfig)
}
}
func TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.0-flash-exp",
map[string]any{"thinking_level": "high"},
)
if _, ok := body["generationConfig"]; ok {
t.Fatalf("generationConfig should be omitted for Gemini 2.0 when only thinking_level is set: %#v", body)
}
}
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
nil,
)
generationConfig, ok := body["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
}
if got := thinkingConfig["thinkingBudget"]; got != 0 {
t.Fatalf("thinkingBudget = %#v, want 0 for default/off", got)
}
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
}
}
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini3(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-3-flash-preview",
nil,
)
generationConfig, ok := body["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
}
if got := thinkingConfig["thinkingLevel"]; got != "minimal" {
t.Fatalf("thinkingLevel = %#v, want minimal for default/off", got)
}
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
}
}
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25Pro(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-pro",
nil,
)
generationConfig, ok := body["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
}
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
}
if _, hasBudget := thinkingConfig["thinkingBudget"]; hasBudget {
t.Fatalf("thinkingBudget should be omitted for Gemini 2.5 Pro default/off: %#v", thinkingConfig)
}
}
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini31Pro(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-3.1-pro",
nil,
)
generationConfig, ok := body["generationConfig"].(map[string]any)
if !ok {
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
}
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
if !ok {
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
}
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
}
if _, hasLevel := thinkingConfig["thinkingLevel"]; hasLevel {
t.Fatalf("thinkingLevel should be omitted for Gemini 3.1 Pro default/off: %#v", thinkingConfig)
}
}
func TestGeminiProvider_BuildRequestBody_PreservesMultipleSystemMessages(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{
{Role: "system", Content: "You are helpful."},
{Role: "system", Content: "Be concise."},
{Role: "user", Content: "hello"},
},
nil,
"gemini-3-flash-preview",
nil,
)
systemInstruction, ok := body["systemInstruction"].(*geminiContent)
if !ok || systemInstruction == nil {
t.Fatalf("systemInstruction = %#v, want *geminiContent", body["systemInstruction"])
}
if len(systemInstruction.Parts) != 2 {
t.Fatalf("systemInstruction.Parts len = %d, want 2", len(systemInstruction.Parts))
}
if systemInstruction.Parts[0].Text != "You are helpful." || systemInstruction.Parts[1].Text != "Be concise." {
t.Fatalf("systemInstruction.Parts = %#v, want ordered system prompts", systemInstruction.Parts)
}
}
func TestGeminiProvider_BuildRequestBody_PreservesToolResponseMedia(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
body := provider.buildRequestBody(
[]Message{
{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "load_image",
Arguments: map[string]any{"path": "demo.png"},
}},
},
{
Role: "tool",
ToolCallID: "call_1",
Content: "tool result",
Media: []string{
"data:image/png;base64,aW1hZ2VEYXRh",
"data:application/pdf;base64,UEZERGF0YQ==",
},
},
},
nil,
"gemini-3-flash-preview",
nil,
)
contents, ok := body["contents"].([]geminiContent)
if !ok || len(contents) != 2 {
t.Fatalf("contents = %#v, want two content entries", body["contents"])
}
parts := contents[1].Parts
if len(parts) != 1 || parts[0].FunctionResponse == nil {
t.Fatalf("tool response part = %#v, want functionResponse", parts)
}
response := parts[0].FunctionResponse
if response.Name != "load_image" {
t.Fatalf("functionResponse.Name = %q, want %q", response.Name, "load_image")
}
if response.Response["result"] != "tool result" {
t.Fatalf("functionResponse.Response = %#v, want result=tool result", response.Response)
}
if len(response.Parts) != 2 {
t.Fatalf("functionResponse.Parts len = %d, want 2", len(response.Parts))
}
}
func TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
t.Fatalf("Authorization = %q, want %q", got, "Bearer test-token")
}
if got := r.Header.Get("X-Goog-Api-Key"); got != "" {
t.Fatalf("X-Goog-Api-Key = %q, want empty", got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"candidates": []any{
map[string]any{
"content": map[string]any{
"parts": []any{map[string]any{"text": "ok"}},
},
"finishReason": "STOP",
},
},
})
}))
defer server.Close()
provider := NewGeminiProvider(
"",
server.URL,
"",
"",
0,
nil,
map[string]string{"Authorization": "Bearer test-token"},
)
resp, err := provider.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
nil,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if resp.Content != "ok" {
t.Fatalf("Content = %q, want %q", resp.Content, "ok")
}
}
func TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Goog-Api-Key"); got != "" {
t.Fatalf("X-Goog-Api-Key = %q, want empty", got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"candidates": []any{
map[string]any{
"content": map[string]any{"parts": []any{map[string]any{"text": "ok"}}},
"finishReason": "STOP",
},
},
})
}))
defer server.Close()
provider := NewGeminiProvider("", server.URL, "", "", 0, nil, nil)
resp, err := provider.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hello"}},
nil,
"gemini-2.5-flash",
nil,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if resp.Content != "ok" {
t.Fatalf("Content = %q, want %q", resp.Content, "ok")
}
}
+2 -3
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log"
"maps"
"net/http"
"net/url"
"strings"
@@ -181,9 +182,7 @@ func (p *Provider) buildRequestBody(
// Merge extra body fields configured per-provider/model.
// These are injected last so they take precedence over defaults.
for k, v := range p.extraBody {
requestBody[k] = v
}
maps.Copy(requestBody, p.extraBody)
return requestBody
}
+13
View File
@@ -281,6 +281,12 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
}
case "assistant":
// Reasoning-only assistant messages are transient display artifacts and
// should not be restored from session history.
if assistantMessageTransientThought(msg) {
continue
}
toolSummaryMessages := visibleAssistantToolSummaryMessages(msg.ToolCalls, toolFeedbackMaxArgsLength)
if len(toolSummaryMessages) > 0 {
transcript = append(transcript, toolSummaryMessages...)
@@ -309,6 +315,13 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
return transcript
}
func assistantMessageTransientThought(msg providers.Message) bool {
return strings.TrimSpace(msg.Content) == "" &&
strings.TrimSpace(msg.ReasoningContent) != "" &&
len(msg.ToolCalls) == 0 &&
len(msg.Media) == 0
}
func assistantMessageInternalOnly(msg providers.Message) bool {
return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText
}
+53
View File
@@ -218,6 +218,59 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) {
}
}
func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "detail-transient-thought"
for _, msg := range []providers.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", ReasoningContent: "internal chain of thought"},
{Role: "assistant", Content: "final visible answer"},
} {
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-transient-thought", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(resp.Messages) != 2 {
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
}
if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "hello" {
t.Fatalf("first message = %#v, want user/hello", resp.Messages[0])
}
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "final visible answer" {
t.Fatalf("second message = %#v, want assistant/final visible answer", resp.Messages[1])
}
}
func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
@@ -1,5 +1,6 @@
import { IconCheck, IconCopy } from "@tabler/icons-react"
import { IconBrain, IconCheck, IconCopy } from "@tabler/icons-react"
import { useState } from "react"
import { useTranslation } from "react-i18next"
import ReactMarkdown from "react-markdown"
import rehypeRaw from "rehype-raw"
import rehypeSanitize from "rehype-sanitize"
@@ -7,16 +8,20 @@ import remarkGfm from "remark-gfm"
import { Button } from "@/components/ui/button"
import { formatMessageTime } from "@/hooks/use-pico-chat"
import { cn } from "@/lib/utils"
interface AssistantMessageProps {
content: string
isThought?: boolean
timestamp?: string | number
}
export function AssistantMessage({
content,
isThought = false,
timestamp = "",
}: AssistantMessageProps) {
const { t } = useTranslation()
const [isCopied, setIsCopied] = useState(false)
const formattedTimestamp =
timestamp !== "" ? formatMessageTime(timestamp) : ""
@@ -33,6 +38,12 @@ export function AssistantMessage({
<div className="text-muted-foreground flex items-center justify-between gap-2 px-1 text-xs opacity-70">
<div className="flex items-center gap-2">
<span>PicoClaw</span>
{isThought && (
<span className="inline-flex items-center gap-1 rounded-full border border-amber-300/80 bg-amber-100/80 px-2 py-0.5 text-[11px] font-medium text-amber-800 dark:border-amber-500/40 dark:bg-amber-500/15 dark:text-amber-200">
<IconBrain className="size-3" />
<span>{t("chat.reasoningLabel")}</span>
</span>
)}
{formattedTimestamp && (
<>
<span className="opacity-50"></span>
@@ -42,8 +53,22 @@ export function AssistantMessage({
</div>
</div>
<div className="bg-card text-card-foreground relative overflow-hidden rounded-xl border">
<div className="prose dark:prose-invert prose-p:my-2 prose-pre:my-2 prose-pre:overflow-x-auto prose-pre:rounded-lg prose-pre:border prose-pre:bg-zinc-950 prose-pre:p-3 max-w-none p-4 text-[15px] leading-relaxed [overflow-wrap:anywhere] break-words">
<div
className={cn(
"relative overflow-hidden rounded-xl border",
isThought
? "border-amber-200/90 bg-amber-50/70 text-amber-950 dark:border-amber-500/35 dark:bg-amber-500/10 dark:text-amber-100"
: "bg-card text-card-foreground",
)}
>
<div
className={cn(
"prose dark:prose-invert prose-pre:my-2 prose-pre:overflow-x-auto prose-pre:rounded-lg prose-pre:border prose-pre:bg-zinc-950 prose-pre:p-3 max-w-none [overflow-wrap:anywhere] break-words",
isThought
? "prose-p:my-1.5 p-3 text-[13px] leading-relaxed opacity-90"
: "prose-p:my-2 p-4 text-[15px] leading-relaxed",
)}
>
<ReactMarkdown
remarkPlugins={[remarkGfm]}
rehypePlugins={[rehypeRaw, rehypeSanitize]}
@@ -54,7 +79,12 @@ export function AssistantMessage({
<Button
variant="ghost"
size="icon"
className="bg-background/50 hover:bg-background/80 absolute top-2 right-2 h-7 w-7 opacity-0 transition-opacity group-hover:opacity-100"
className={cn(
"absolute top-2 right-2 h-7 w-7 opacity-0 transition-opacity group-hover:opacity-100",
isThought
? "bg-amber-100/70 hover:bg-amber-200/80 dark:bg-amber-500/20 dark:hover:bg-amber-400/30"
: "bg-background/50 hover:bg-background/80",
)}
onClick={handleCopy}
>
{isCopied ? (
@@ -247,6 +247,7 @@ export function ChatPage() {
{msg.role === "assistant" ? (
<AssistantMessage
content={msg.content}
isThought={msg.kind === "thought"}
timestamp={msg.timestamp}
/>
) : (
+2 -1
View File
@@ -24,6 +24,7 @@ export async function loadSessionMessages(
id: `hist-${index}-${Date.now()}`,
role: message.role,
content: message.content,
kind: message.role === "assistant" ? "normal" : undefined,
attachments: toChatAttachments(message.media),
timestamp: fallbackTime,
}))
@@ -50,7 +51,7 @@ function messageSignature(message: ChatMessage): string {
return `${message.role}\u0000${message.content}\u0000${normalizeMessageTimestamp(
message.timestamp,
)}\u0000${attachmentSignature}`
)}\u0000${message.kind ?? ""}\u0000${attachmentSignature}`
}
function comparableTimestamp(timestamp: number | string): number {
+25 -2
View File
@@ -1,7 +1,10 @@
import { toast } from "sonner"
import { normalizeUnixTimestamp } from "@/features/chat/state"
import { updateChatStore } from "@/store/chat"
import {
type AssistantMessageKind,
updateChatStore,
} from "@/store/chat"
export interface PicoMessage {
type: string
@@ -11,6 +14,16 @@ export interface PicoMessage {
payload?: Record<string, unknown>
}
function parseAssistantMessageKind(
payload: Record<string, unknown>,
): AssistantMessageKind {
return payload.thought === true ? "thought" : "normal"
}
function hasAssistantKindPayload(payload: Record<string, unknown>): boolean {
return typeof payload.thought === "boolean"
}
export function handlePicoMessage(
message: PicoMessage,
expectedSessionId: string,
@@ -25,6 +38,7 @@ export function handlePicoMessage(
case "message.create": {
const content = (payload.content as string) || ""
const messageId = (payload.message_id as string) || `pico-${Date.now()}`
const kind = parseAssistantMessageKind(payload)
const timestamp =
message.timestamp !== undefined &&
Number.isFinite(Number(message.timestamp))
@@ -38,6 +52,7 @@ export function handlePicoMessage(
id: messageId,
role: "assistant",
content,
kind,
timestamp,
},
],
@@ -49,13 +64,21 @@ export function handlePicoMessage(
case "message.update": {
const content = (payload.content as string) || ""
const messageId = payload.message_id as string
const hasKind = hasAssistantKindPayload(payload)
const kind = parseAssistantMessageKind(payload)
if (!messageId) {
break
}
updateChatStore((prev) => ({
messages: prev.messages.map((msg) =>
msg.id === messageId ? { ...msg, content } : msg,
msg.id === messageId
? {
...msg,
content,
...(hasKind ? { kind } : {}),
}
: msg,
),
}))
break
+1
View File
@@ -47,6 +47,7 @@
"step3": "Preparing response...",
"step4": "Almost there..."
},
"reasoningLabel": "Reasoning",
"history": "History",
"noHistory": "No chat history yet",
"historyLoadFailed": "Failed to load chat history",
+1
View File
@@ -47,6 +47,7 @@
"step3": "准备回复...",
"step4": "马上就好..."
},
"reasoningLabel": "思考",
"history": "历史记录",
"noHistory": "暂无对话历史",
"historyLoadFailed": "加载历史记录失败",
+3
View File
@@ -11,11 +11,14 @@ export interface ChatAttachment {
filename?: string
}
export type AssistantMessageKind = "normal" | "thought"
export interface ChatMessage {
id: string
role: "user" | "assistant"
content: string
timestamp: number | string
kind?: AssistantMessageKind
attachments?: ChatAttachment[]
}