mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2475 from lc6464/fix/issue-2448-separate-thought-message
feat(gemini,pico): separate thought messages and add native Gemini provider
This commit is contained in:
+48
-7
@@ -105,6 +105,8 @@ const (
|
||||
toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
|
||||
handledToolResponseSummary = "Requested output delivered via tool attachment."
|
||||
sessionKeyAgentPrefix = "agent:"
|
||||
metadataKeyMessageKind = "message_kind"
|
||||
messageKindThought = "thought"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
metadataKeyTeamID = "team_id"
|
||||
@@ -1622,6 +1624,41 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, chatID string) {
|
||||
if reasoningContent == "" || chatID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer pubCancel()
|
||||
|
||||
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: chatID,
|
||||
Content: reasoningContent,
|
||||
Metadata: map[string]string{
|
||||
metadataKeyMessageKind: messageKindThought,
|
||||
},
|
||||
}); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, bus.ErrBusClosed) {
|
||||
logger.DebugCF("agent", "Pico reasoning publish skipped (timeout/cancel)", map[string]any{
|
||||
"channel": "pico",
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
logger.WarnCF("agent", "Failed to publish pico reasoning (best-effort)", map[string]any{
|
||||
"channel": "pico",
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, channelName, channelID string,
|
||||
@@ -2223,12 +2260,16 @@ turnLoop:
|
||||
if reasoningContent == "" {
|
||||
reasoningContent = response.ReasoningContent
|
||||
}
|
||||
go al.handleReasoning(
|
||||
turnCtx,
|
||||
reasoningContent,
|
||||
ts.channel,
|
||||
al.targetReasoningChannelID(ts.channel),
|
||||
)
|
||||
if ts.channel == "pico" {
|
||||
go al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID)
|
||||
} else {
|
||||
go al.handleReasoning(
|
||||
turnCtx,
|
||||
reasoningContent,
|
||||
ts.channel,
|
||||
al.targetReasoningChannelID(ts.channel),
|
||||
)
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindLLMResponse,
|
||||
ts.eventMeta("runTurn", "turn.llm.response"),
|
||||
@@ -2277,7 +2318,7 @@ turnLoop:
|
||||
|
||||
if len(response.ToolCalls) == 0 || gracefulTerminal {
|
||||
responseContent := response.Content
|
||||
if responseContent == "" && response.ReasoningContent != "" {
|
||||
if responseContent == "" && response.ReasoningContent != "" && ts.channel != "pico" {
|
||||
responseContent = response.ReasoningContent
|
||||
}
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
|
||||
+57
-1
@@ -1921,7 +1921,7 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
|
||||
},
|
||||
{
|
||||
ModelName: "gemma-fallback",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
Model: "openrouter/gemma-3-27b-it",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
Workspace: workspace,
|
||||
@@ -2660,6 +2660,62 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_PicoPublishesReasoningAsThoughtMessage(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &reasoningContentProvider{
|
||||
response: "final answer",
|
||||
reasoningContent: "thinking trace",
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
SenderID: "user1",
|
||||
ChatID: "pico:test-session",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "final answer" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "final answer")
|
||||
}
|
||||
|
||||
var thoughtMsg *bus.OutboundMessage
|
||||
deadline := time.After(3 * time.Second)
|
||||
|
||||
for thoughtMsg == nil {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
msg := outbound
|
||||
if msg.Content == "thinking trace" {
|
||||
thoughtMsg = &msg
|
||||
}
|
||||
case <-deadline:
|
||||
t.Fatal("expected thought outbound message for pico")
|
||||
}
|
||||
}
|
||||
|
||||
if thoughtMsg.Channel != "pico" || thoughtMsg.ChatID != "pico:test-session" {
|
||||
t.Fatalf("thought message route = %s/%s, want pico/pico:test-session", thoughtMsg.Channel, thoughtMsg.ChatID)
|
||||
}
|
||||
if thoughtMsg.Metadata[metadataKeyMessageKind] != messageKindThought {
|
||||
t.Fatalf("thought metadata kind = %q, want %q", thoughtMsg.Metadata[metadataKeyMessageKind], messageKindThought)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessHeartbeat_DoesNotPublishToolFeedback(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
heartbeatFile := filepath.Join(tmpDir, "heartbeat-task.txt")
|
||||
|
||||
@@ -242,7 +242,11 @@ func (c *PicoClientChannel) handleInbound(pc *picoConn, msg PicoMessage) {
|
||||
}
|
||||
|
||||
func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
content, _ := msg.Payload["content"].(string)
|
||||
if isThoughtPayload(msg.Payload) {
|
||||
return
|
||||
}
|
||||
|
||||
content, _ := msg.Payload[PayloadKeyContent].(string)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
@@ -285,7 +289,7 @@ func (c *PicoClientChannel) Send(ctx context.Context, msg bus.OutboundMessage) (
|
||||
}
|
||||
|
||||
outMsg := newMessage(TypeMessageSend, map[string]any{
|
||||
"content": msg.Content,
|
||||
PayloadKeyContent: msg.Content,
|
||||
})
|
||||
outMsg.SessionID = strings.TrimPrefix(msg.ChatID, "pico_client:")
|
||||
return nil, pc.writeJSON(outMsg)
|
||||
|
||||
@@ -316,3 +316,67 @@ func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
|
||||
t.Fatal("timed out waiting for inbound media message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsThoughtPayload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload map[string]any
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "explicit thought bool",
|
||||
payload: map[string]any{PayloadKeyThought: true},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "thought false",
|
||||
payload: map[string]any{PayloadKeyThought: false},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "thought string ignored",
|
||||
payload: map[string]any{PayloadKeyThought: "true"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "default normal",
|
||||
payload: map[string]any{PayloadKeyContent: "hello"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isThoughtPayload(tt.payload); got != tt.want {
|
||||
t.Fatalf("isThoughtPayload() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
URL: "ws://localhost:8080/ws",
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("NewPicoClientChannel() error = %v", err)
|
||||
}
|
||||
|
||||
ch.ctx = context.Background()
|
||||
pc := &picoConn{sessionID: "sess-thought"}
|
||||
|
||||
ch.handleServerMessage(pc, PicoMessage{
|
||||
Type: TypeMessageCreate,
|
||||
Payload: map[string]any{
|
||||
PayloadKeyContent: "internal reasoning",
|
||||
PayloadKeyThought: true,
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case msg := <-mb.InboundChan():
|
||||
t.Fatalf("expected no inbound publish for thought payload, got %+v", msg)
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +39,13 @@ var allowedInlineImageMIMETypes = map[string]struct{}{
|
||||
"image/bmp": {},
|
||||
}
|
||||
|
||||
func outboundMessageIsThought(metadata map[string]string) bool {
|
||||
if len(metadata) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(metadata["message_kind"]), MessageKindThought)
|
||||
}
|
||||
|
||||
// writeJSON sends a JSON message to the connection with write locking.
|
||||
func (pc *picoConn) writeJSON(v any) error {
|
||||
if pc.closed.Load() {
|
||||
@@ -247,9 +254,11 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
isThought := outboundMessageIsThought(msg.Metadata)
|
||||
|
||||
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
||||
"content": msg.Content,
|
||||
PayloadKeyContent: msg.Content,
|
||||
PayloadKeyThought: isThought,
|
||||
})
|
||||
|
||||
return nil, c.broadcastToSession(msg.ChatID, outMsg)
|
||||
@@ -288,8 +297,9 @@ func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (strin
|
||||
|
||||
msgID := uuid.New().String()
|
||||
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
||||
"content": text,
|
||||
"message_id": msgID,
|
||||
PayloadKeyContent: text,
|
||||
PayloadKeyThought: false,
|
||||
"message_id": msgID,
|
||||
})
|
||||
|
||||
if err := c.broadcastToSession(chatID, outMsg); err != nil {
|
||||
|
||||
@@ -19,6 +19,11 @@ const (
|
||||
TypePong = "pong"
|
||||
|
||||
PicoTokenPrefix = "pico-"
|
||||
|
||||
PayloadKeyContent = "content"
|
||||
PayloadKeyThought = "thought"
|
||||
|
||||
MessageKindThought = "thought"
|
||||
)
|
||||
|
||||
// PicoMessage is the wire format for all Pico Protocol messages.
|
||||
@@ -39,6 +44,11 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
|
||||
}
|
||||
}
|
||||
|
||||
func isThoughtPayload(payload map[string]any) bool {
|
||||
thought, _ := payload[PayloadKeyThought].(bool)
|
||||
return thought
|
||||
}
|
||||
|
||||
func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage {
|
||||
payload := map[string]any{
|
||||
"code": code,
|
||||
|
||||
@@ -389,6 +389,7 @@ type antigravityJSONResponse struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
ThoughtSignatureSnake string `json:"thought_signature,omitempty"`
|
||||
FunctionCall *antigravityFunctionCall `json:"functionCall,omitempty"`
|
||||
@@ -406,6 +407,7 @@ type antigravityJSONResponse struct {
|
||||
|
||||
func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error) {
|
||||
var contentParts []string
|
||||
var reasoningParts []string
|
||||
var toolCalls []ToolCall
|
||||
var usage *UsageInfo
|
||||
var finishReason string
|
||||
@@ -433,7 +435,11 @@ func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error
|
||||
for _, candidate := range resp.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
contentParts = append(contentParts, part.Text)
|
||||
if part.Thought {
|
||||
reasoningParts = append(reasoningParts, part.Text)
|
||||
} else {
|
||||
contentParts = append(contentParts, part.Text)
|
||||
}
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
argumentsJSON, _ := json.Marshal(part.FunctionCall.Args)
|
||||
@@ -475,10 +481,11 @@ func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.Join(contentParts, ""),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: mappedFinish,
|
||||
Usage: usage,
|
||||
Content: strings.Join(contentParts, ""),
|
||||
ReasoningContent: strings.Join(reasoningParts, ""),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: mappedFinish,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -54,3 +54,27 @@ func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) {
|
||||
t.Fatalf("expected inferred tool name search_docs, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
|
||||
p := &AntigravityProvider{}
|
||||
body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" +
|
||||
"data: [DONE]\n"
|
||||
|
||||
resp, err := p.parseSSEResponse(body)
|
||||
if err != nil {
|
||||
t.Fatalf("parseSSEResponse() error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Content != "visible answer" {
|
||||
t.Fatalf("Content = %q, want %q", resp.Content, "visible answer")
|
||||
}
|
||||
if resp.ReasoningContent != "hidden reasoning" {
|
||||
t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden reasoning")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 216 {
|
||||
t.Fatalf("Usage.TotalTokens = %v, want %d", resp.Usage, 216)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func ResolveAPIBase(cfg *config.ModelConfig) string {
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
|
||||
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq),
|
||||
// Azure OpenAI, Amazon Bedrock, Anthropic (including messages), and various CLI/compatibility shims.
|
||||
// See the switch on protocol in this function for the authoritative list.
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
@@ -218,7 +218,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
|
||||
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "gemini", "nvidia", "venice",
|
||||
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita",
|
||||
@@ -242,6 +242,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "gemini":
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for gemini protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewGeminiProvider(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "minimax":
|
||||
// Minimax requires reasoning_split: true in the request body
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
|
||||
@@ -434,6 +434,62 @@ func TestCreateProviderFromConfig_Antigravity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Gemini(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-gemini",
|
||||
Model: "gemini/gemini-2.5-flash",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "gemini-2.5-flash" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash")
|
||||
}
|
||||
if _, ok := provider.(*GeminiProvider); !ok {
|
||||
t.Fatalf("expected *GeminiProvider, got %T", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_GeminiMissingAPIKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-gemini-no-key",
|
||||
Model: "gemini/gemini-2.5-flash",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing gemini API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_GeminiCustomAPIBaseWithoutKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-gemini-custom-base",
|
||||
Model: "gemini/gemini-2.5-flash",
|
||||
APIBase: "https://proxy.example.com/v1beta",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "gemini-2.5-flash" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "gemini-2.5-flash")
|
||||
}
|
||||
if _, ok := provider.(*GeminiProvider); !ok {
|
||||
t.Fatalf("expected *GeminiProvider, got %T", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-claude-cli",
|
||||
|
||||
@@ -0,0 +1,796 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
const (
|
||||
geminiDefaultAPIBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
geminiDefaultModel = "gemini-2.0-flash"
|
||||
)
|
||||
|
||||
type GeminiProvider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
extraBody map[string]any
|
||||
customHeaders map[string]string
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func NewGeminiProvider(
|
||||
apiKey string,
|
||||
apiBase string,
|
||||
proxy string,
|
||||
userAgent string,
|
||||
requestTimeoutSeconds int,
|
||||
extraBody map[string]any,
|
||||
customHeaders map[string]string,
|
||||
) *GeminiProvider {
|
||||
if strings.TrimSpace(apiBase) == "" {
|
||||
apiBase = geminiDefaultAPIBase
|
||||
}
|
||||
client := common.NewHTTPClient(proxy)
|
||||
if requestTimeoutSeconds > 0 {
|
||||
client.Timeout = time.Duration(requestTimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
return &GeminiProvider{
|
||||
apiKey: strings.TrimSpace(apiKey),
|
||||
apiBase: strings.TrimRight(strings.TrimSpace(apiBase), "/"),
|
||||
httpClient: client,
|
||||
extraBody: cloneAnyMap(extraBody),
|
||||
customHeaders: cloneStringMap(customHeaders),
|
||||
userAgent: strings.TrimSpace(userAgent),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) GetDefaultModel() string {
|
||||
return geminiDefaultModel
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) SupportsThinking() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
model = normalizeGeminiModel(model)
|
||||
requestBody := p.buildRequestBody(messages, tools, model, options)
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:generateContent", p.apiBase, model)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
p.applyHeaders(req)
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
var apiResp geminiGenerateContentResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return parseGeminiResponse(&apiResp), nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatStream(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
model = normalizeGeminiModel(model)
|
||||
requestBody := p.buildRequestBody(messages, tools, model, options)
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?alt=sse", p.apiBase, model)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
p.applyHeaders(req)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
// Streaming should not use a whole-request timeout; context cancellation is the guard.
|
||||
streamClient := &http.Client{Transport: p.httpClient.Transport}
|
||||
resp, err := streamClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
return parseGeminiStreamResponse(ctx, resp.Body, onChunk)
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) applyHeaders(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("X-Goog-Api-Key", p.apiKey)
|
||||
}
|
||||
if p.userAgent != "" {
|
||||
req.Header.Set("User-Agent", p.userAgent)
|
||||
}
|
||||
for k, v := range p.customHeaders {
|
||||
if strings.TrimSpace(k) == "" {
|
||||
continue
|
||||
}
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) buildRequestBody(
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) map[string]any {
|
||||
contents := make([]geminiContent, 0, len(messages))
|
||||
toolCallNames := make(map[string]string)
|
||||
systemPrompts := make([]string, 0, 1)
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
if strings.TrimSpace(msg.Content) != "" {
|
||||
systemPrompts = append(systemPrompts, msg.Content)
|
||||
}
|
||||
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
contents = append(contents, geminiContent{
|
||||
Role: "user",
|
||||
Parts: []geminiPart{{
|
||||
FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
|
||||
}},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
parts := make([]geminiPart, 0, 1+len(msg.Media))
|
||||
if strings.TrimSpace(msg.Content) != "" {
|
||||
parts = append(parts, geminiPart{Text: msg.Content})
|
||||
}
|
||||
parts = append(parts, buildInlineMediaParts(msg.Media)...)
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, geminiContent{Role: "user", Parts: parts})
|
||||
}
|
||||
|
||||
case "assistant":
|
||||
content := geminiContent{Role: "model"}
|
||||
if strings.TrimSpace(msg.Content) != "" {
|
||||
content.Parts = append(content.Parts, geminiPart{Text: msg.Content})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
|
||||
if toolName == "" {
|
||||
continue
|
||||
}
|
||||
if tc.ID != "" {
|
||||
toolCallNames[tc.ID] = toolName
|
||||
}
|
||||
part := geminiPart{
|
||||
FunctionCall: &geminiFunctionCall{
|
||||
Name: toolName,
|
||||
Args: toolArgs,
|
||||
ID: tc.ID,
|
||||
},
|
||||
}
|
||||
if thoughtSignature != "" {
|
||||
part.ThoughtSignature = thoughtSignature
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
if len(content.Parts) > 0 {
|
||||
contents = append(contents, content)
|
||||
}
|
||||
|
||||
case "tool":
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
contents = append(contents, geminiContent{
|
||||
Role: "user",
|
||||
Parts: []geminiPart{{
|
||||
FunctionResponse: buildGeminiFunctionResponse(toolName, msg.ToolCallID, msg.Content, msg.Media),
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"contents": contents,
|
||||
}
|
||||
if len(systemPrompts) > 0 {
|
||||
systemParts := make([]geminiPart, 0, len(systemPrompts))
|
||||
for _, prompt := range systemPrompts {
|
||||
systemParts = append(systemParts, geminiPart{Text: prompt})
|
||||
}
|
||||
body["systemInstruction"] = &geminiContent{Parts: systemParts}
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
funcDecls := make([]geminiFunctionDeclaration, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
funcDecls = append(funcDecls, geminiFunctionDeclaration{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: sanitizeSchemaForGemini(t.Function.Parameters),
|
||||
})
|
||||
}
|
||||
if len(funcDecls) > 0 {
|
||||
body["tools"] = []geminiTool{{FunctionDeclarations: funcDecls}}
|
||||
}
|
||||
}
|
||||
|
||||
generationConfig := make(map[string]any)
|
||||
if val, ok := options["max_tokens"]; ok {
|
||||
if maxTokens, ok := val.(int); ok && maxTokens > 0 {
|
||||
generationConfig["maxOutputTokens"] = maxTokens
|
||||
} else if maxTokens, ok := val.(float64); ok && maxTokens > 0 {
|
||||
generationConfig["maxOutputTokens"] = int(maxTokens)
|
||||
}
|
||||
}
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
generationConfig["temperature"] = temp
|
||||
}
|
||||
|
||||
if thinkingConfig := buildGeminiThinkingConfig(model, options); len(thinkingConfig) > 0 {
|
||||
generationConfig["thinkingConfig"] = thinkingConfig
|
||||
}
|
||||
|
||||
if len(generationConfig) > 0 {
|
||||
body["generationConfig"] = generationConfig
|
||||
}
|
||||
|
||||
for k, v := range p.extraBody {
|
||||
body[k] = v
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGeminiModel(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
if strings.Contains(model, "/") {
|
||||
_, modelID := ExtractProtocol(model)
|
||||
if modelID != "" {
|
||||
return modelID
|
||||
}
|
||||
}
|
||||
if model == "" {
|
||||
return geminiDefaultModel
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
func mapGeminiThinkingLevel(level string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "minimal", "off":
|
||||
return "minimal"
|
||||
case "low":
|
||||
return "low"
|
||||
case "medium":
|
||||
return "medium"
|
||||
case "high", "xhigh", "adaptive":
|
||||
return "high"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiThinkingConfig(model string, options map[string]any) map[string]any {
|
||||
if !geminiModelSupportsThinkingConfig(model) {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := map[string]any{}
|
||||
rawLevel, _ := options["thinking_level"].(string)
|
||||
rawLevel = strings.ToLower(strings.TrimSpace(rawLevel))
|
||||
if rawLevel == "" {
|
||||
// Align with agent-level default: unset means ThinkingOff.
|
||||
rawLevel = "off"
|
||||
}
|
||||
|
||||
includeThoughts := rawLevel != "off" && rawLevel != "minimal"
|
||||
config["includeThoughts"] = includeThoughts
|
||||
|
||||
if isGemini25Model(model) {
|
||||
if isGemini25ProModel(model) && (rawLevel == "off" || rawLevel == "minimal") {
|
||||
// Gemini 2.5 Pro cannot disable thinking; keep model-default thinking.
|
||||
return config
|
||||
}
|
||||
if budget, ok := mapGeminiThinkingBudget(rawLevel); ok {
|
||||
config["thinkingBudget"] = budget
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
if isGemini3ProModel(model) && (rawLevel == "off" || rawLevel == "minimal") {
|
||||
// Gemini 3.x Pro does not support minimal thinking level.
|
||||
return config
|
||||
}
|
||||
|
||||
if thinkingLevel := mapGeminiThinkingLevel(rawLevel); thinkingLevel != "" {
|
||||
config["thinkingLevel"] = thinkingLevel
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func geminiModelSupportsThinkingConfig(model string) bool {
|
||||
lowerModel := strings.ToLower(strings.TrimSpace(model))
|
||||
return strings.Contains(lowerModel, "gemini-3") || isGemini25Model(lowerModel)
|
||||
}
|
||||
|
||||
func isGemini25Model(model string) bool {
|
||||
lowerModel := strings.ToLower(strings.TrimSpace(model))
|
||||
return strings.Contains(lowerModel, "gemini-2.5") || strings.Contains(lowerModel, "gemini-25")
|
||||
}
|
||||
|
||||
func isGemini25ProModel(model string) bool {
|
||||
lowerModel := strings.ToLower(strings.TrimSpace(model))
|
||||
return isGemini25Model(lowerModel) && strings.Contains(lowerModel, "pro")
|
||||
}
|
||||
|
||||
func isGemini3ProModel(model string) bool {
|
||||
lowerModel := strings.ToLower(strings.TrimSpace(model))
|
||||
return strings.Contains(lowerModel, "gemini-3") && strings.Contains(lowerModel, "pro")
|
||||
}
|
||||
|
||||
func mapGeminiThinkingBudget(level string) (int, bool) {
|
||||
level = strings.ToLower(strings.TrimSpace(level))
|
||||
if level == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
switch level {
|
||||
case "adaptive":
|
||||
return -1, true
|
||||
case "minimal":
|
||||
return 0, true
|
||||
case "off":
|
||||
return 0, true
|
||||
case "low":
|
||||
return 1024, true
|
||||
case "medium":
|
||||
return 4096, true
|
||||
case "high":
|
||||
return 8192, true
|
||||
case "xhigh":
|
||||
return 16384, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseGeminiResponse(resp *geminiGenerateContentResponse) *LLMResponse {
|
||||
contentParts := make([]string, 0)
|
||||
reasoningParts := make([]string, 0)
|
||||
toolCalls := make([]ToolCall, 0)
|
||||
finishReason := ""
|
||||
|
||||
for _, candidate := range resp.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
if part.Thought {
|
||||
reasoningParts = append(reasoningParts, part.Text)
|
||||
} else {
|
||||
contentParts = append(contentParts, part.Text)
|
||||
}
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, buildGeminiToolCall(part))
|
||||
}
|
||||
}
|
||||
if candidate.FinishReason != "" {
|
||||
finishReason = candidate.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
var usage *UsageInfo
|
||||
if resp.UsageMetadata.TotalTokenCount > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: resp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: resp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: resp.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.Join(contentParts, ""),
|
||||
ReasoningContent: strings.Join(reasoningParts, ""),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)),
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func parseGeminiStreamResponse(
|
||||
ctx context.Context,
|
||||
reader io.Reader,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error) {
|
||||
var contentBuilder strings.Builder
|
||||
var reasoningBuilder strings.Builder
|
||||
var finishReason string
|
||||
var usage *UsageInfo
|
||||
|
||||
toolCallsByID := make(map[string]ToolCall)
|
||||
toolCallOrder := make([]string, 0)
|
||||
fallbackIndex := 0
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, 1024*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk geminiGenerateContentResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return nil, fmt.Errorf("invalid gemini stream chunk: %w", err)
|
||||
}
|
||||
|
||||
for _, candidate := range chunk.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
if part.Thought {
|
||||
reasoningBuilder.WriteString(part.Text)
|
||||
} else {
|
||||
contentBuilder.WriteString(part.Text)
|
||||
if onChunk != nil {
|
||||
onChunk(contentBuilder.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
tc := buildGeminiToolCall(part)
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(part.FunctionCall.ID)
|
||||
if key == "" {
|
||||
if len(toolCallOrder) > 0 {
|
||||
lastKey := toolCallOrder[len(toolCallOrder)-1]
|
||||
if lastTC, exists := toolCallsByID[lastKey]; exists && lastTC.Name == tc.Name {
|
||||
key = lastKey
|
||||
}
|
||||
}
|
||||
if key == "" {
|
||||
fallbackIndex++
|
||||
key = fmt.Sprintf("%s#%d", tc.Name, fallbackIndex)
|
||||
}
|
||||
}
|
||||
|
||||
tc.ID = key
|
||||
if _, exists := toolCallsByID[key]; !exists {
|
||||
toolCallOrder = append(toolCallOrder, key)
|
||||
}
|
||||
toolCallsByID[key] = tc
|
||||
}
|
||||
}
|
||||
if candidate.FinishReason != "" {
|
||||
finishReason = candidate.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("streaming read error: %w", err)
|
||||
}
|
||||
|
||||
toolCalls := make([]ToolCall, 0, len(toolCallOrder))
|
||||
for _, key := range toolCallOrder {
|
||||
toolCalls = append(toolCalls, toolCallsByID[key])
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: contentBuilder.String(),
|
||||
ReasoningContent: reasoningBuilder.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: normalizeGeminiFinishReason(finishReason, len(toolCalls)),
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeGeminiFinishReason(reason string, toolCalls int) string {
|
||||
if toolCalls > 0 {
|
||||
return "tool_calls"
|
||||
}
|
||||
|
||||
switch strings.ToUpper(strings.TrimSpace(reason)) {
|
||||
case "MAX_TOKENS":
|
||||
return "length"
|
||||
case "", "STOP":
|
||||
return "stop"
|
||||
default:
|
||||
return strings.ToLower(strings.TrimSpace(reason))
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiToolCall(part geminiPart) ToolCall {
|
||||
if part.FunctionCall == nil {
|
||||
return ToolCall{}
|
||||
}
|
||||
|
||||
args := part.FunctionCall.Args
|
||||
if args == nil {
|
||||
args = make(map[string]any)
|
||||
}
|
||||
argsJSON, _ := json.Marshal(args)
|
||||
thoughtSignature := extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake)
|
||||
|
||||
toolCall := ToolCall{
|
||||
ID: part.FunctionCall.ID,
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: args,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
Function: &FunctionCall{
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: string(argsJSON),
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{ThoughtSignature: thoughtSignature},
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(toolCall.ID) == "" {
|
||||
toolCall.ID = fmt.Sprintf("call_%s_%d", toolCall.Name, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
return toolCall
|
||||
}
|
||||
|
||||
func buildInlineMediaParts(media []string) []geminiPart {
|
||||
parts := make([]geminiPart, 0, len(media))
|
||||
for _, mediaURL := range media {
|
||||
mimeType, data, ok := parseBase64DataURL(mediaURL)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, geminiPart{
|
||||
InlineData: &geminiInlineData{
|
||||
MIMEType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func buildGeminiFunctionResponse(
|
||||
toolName string,
|
||||
toolCallID string,
|
||||
result string,
|
||||
media []string,
|
||||
) *geminiFunctionResponse {
|
||||
response := &geminiFunctionResponse{
|
||||
ID: toolCallID,
|
||||
Name: toolName,
|
||||
Response: map[string]any{
|
||||
"result": result,
|
||||
},
|
||||
}
|
||||
|
||||
if parts := buildFunctionResponseMediaParts(media); len(parts) > 0 {
|
||||
response.Parts = parts
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func buildFunctionResponseMediaParts(media []string) []geminiFunctionResponsePart {
|
||||
parts := make([]geminiFunctionResponsePart, 0, len(media))
|
||||
for i, mediaURL := range media {
|
||||
mimeType, data, ok := parseBase64DataURL(mediaURL)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, geminiFunctionResponsePart{
|
||||
InlineData: &geminiInlineData{
|
||||
MIMEType: mimeType,
|
||||
Data: data,
|
||||
DisplayName: defaultFunctionResponseDisplayName(mimeType, i+1),
|
||||
},
|
||||
})
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func defaultFunctionResponseDisplayName(mimeType string, index int) string {
|
||||
suffix := "bin"
|
||||
switch strings.ToLower(strings.TrimSpace(mimeType)) {
|
||||
case "image/png":
|
||||
suffix = "png"
|
||||
case "image/jpeg":
|
||||
suffix = "jpg"
|
||||
case "image/webp":
|
||||
suffix = "webp"
|
||||
case "application/pdf":
|
||||
suffix = "pdf"
|
||||
case "text/plain":
|
||||
suffix = "txt"
|
||||
}
|
||||
return fmt.Sprintf("attachment-%d.%s", index, suffix)
|
||||
}
|
||||
|
||||
func parseBase64DataURL(mediaURL string) (mimeType string, data string, ok bool) {
|
||||
if !strings.HasPrefix(mediaURL, "data:") {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(mediaURL, "data:")
|
||||
header, data, found := strings.Cut(payload, ",")
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
mimeType, params, _ := strings.Cut(header, ";")
|
||||
mimeType = strings.TrimSpace(mimeType)
|
||||
data = strings.TrimSpace(data)
|
||||
if mimeType == "" || data == "" {
|
||||
return "", "", false
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(params), "base64") {
|
||||
return "", "", false
|
||||
}
|
||||
return mimeType, data, true
|
||||
}
|
||||
|
||||
func cloneAnyMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneStringMap(in map[string]string) map[string]string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type geminiGenerateContentResponse struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []geminiPart `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type geminiContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []geminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type geminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
ThoughtSignatureSnake string `json:"thought_signature,omitempty"`
|
||||
InlineData *geminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *geminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *geminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type geminiInlineData struct {
|
||||
MIMEType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
}
|
||||
|
||||
type geminiFunctionCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Args map[string]any `json:"args,omitempty"`
|
||||
}
|
||||
|
||||
type geminiFunctionResponse struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Response map[string]any `json:"response"`
|
||||
Parts []geminiFunctionResponsePart `json:"parts,omitempty"`
|
||||
}
|
||||
|
||||
type geminiFunctionResponsePart struct {
|
||||
InlineData *geminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type geminiTool struct {
|
||||
FunctionDeclarations []geminiFunctionDeclaration `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type geminiFunctionDeclaration struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,763 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
|
||||
var capturedBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("method = %s, want POST", r.Method)
|
||||
}
|
||||
if !strings.Contains(r.URL.Path, ":generateContent") {
|
||||
t.Fatalf("path = %s, expected generateContent endpoint", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("X-Goog-Api-Key"); got != "test-key" {
|
||||
t.Fatalf("X-Goog-Api-Key = %q, want %q", got, "test-key")
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&capturedBody); err != nil {
|
||||
t.Fatalf("decode request body: %v", err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"candidates": []any{
|
||||
map[string]any{
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": []any{
|
||||
map[string]any{"text": "hidden", "thought": true},
|
||||
map[string]any{"text": "visible"},
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"id": "call_1",
|
||||
"name": "search",
|
||||
"args": map[string]any{"q": "hi"},
|
||||
},
|
||||
"thoughtSignature": "sig-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": 2,
|
||||
"candidatesTokenCount": 3,
|
||||
"totalTokenCount": 5,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewGeminiProvider("test-key", server.URL, "", "picoclaw-test", 0, nil, nil)
|
||||
resp, err := provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-3-flash-preview",
|
||||
map[string]any{"thinking_level": "high"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if resp.Content != "visible" {
|
||||
t.Fatalf("Content = %q, want %q", resp.Content, "visible")
|
||||
}
|
||||
if resp.ReasoningContent != "hidden" {
|
||||
t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "hidden")
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 5 {
|
||||
t.Fatalf("Usage = %#v, expected total tokens = 5", resp.Usage)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].ID != "call_1" {
|
||||
t.Fatalf("ToolCall ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "search" {
|
||||
t.Fatalf("ToolCall Name = %q, want %q", resp.ToolCalls[0].Name, "search")
|
||||
}
|
||||
if resp.ToolCalls[0].ThoughtSignature != "sig-1" {
|
||||
t.Fatalf("ToolCall ThoughtSignature = %q, want %q", resp.ToolCalls[0].ThoughtSignature, "sig-1")
|
||||
}
|
||||
if resp.ToolCalls[0].Function == nil || !strings.Contains(resp.ToolCalls[0].Function.Arguments, `"q":"hi"`) {
|
||||
t.Fatalf("ToolCall Function arguments = %#v, want q=hi", resp.ToolCalls[0].Function)
|
||||
}
|
||||
|
||||
generationConfig, ok := capturedBody["generationConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing generationConfig: %#v", capturedBody)
|
||||
}
|
||||
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing thinkingConfig: %#v", generationConfig)
|
||||
}
|
||||
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts {
|
||||
t.Fatalf("thinkingConfig.includeThoughts = %#v, want true", thinkingConfig["includeThoughts"])
|
||||
}
|
||||
if got := thinkingConfig["thinkingLevel"]; got != "high" {
|
||||
t.Fatalf("thinkingConfig.thinkingLevel = %#v, want %q", got, "high")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatStreamParsesThoughtTextAndToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.URL.Path, ":streamGenerateContent") {
|
||||
t.Fatalf("path = %s, expected streamGenerateContent endpoint", r.URL.Path)
|
||||
}
|
||||
if got := r.URL.Query().Get("alt"); got != "sse" {
|
||||
t.Fatalf("alt query = %q, want %q", got, "sse")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatal("response writer is not flushable")
|
||||
}
|
||||
|
||||
chunks := []map[string]any{
|
||||
{
|
||||
"candidates": []any{map[string]any{
|
||||
"content": map[string]any{
|
||||
"parts": []any{
|
||||
map[string]any{"text": "think ", "thought": true},
|
||||
map[string]any{"text": "Hello "},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
"candidates": []any{map[string]any{
|
||||
"content": map[string]any{
|
||||
"parts": []any{
|
||||
map[string]any{"text": "World"},
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"id": "call_stream",
|
||||
"name": "search",
|
||||
"args": map[string]any{"q": "stream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
}},
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": 1,
|
||||
"candidatesTokenCount": 2,
|
||||
"totalTokenCount": 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
raw, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal chunk: %v", err)
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil {
|
||||
t.Fatalf("write chunk: %v", err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
|
||||
updates := make([]string, 0)
|
||||
resp, err := provider.ChatStream(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
func(accumulated string) {
|
||||
updates = append(updates, accumulated)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStream() error = %v", err)
|
||||
}
|
||||
if resp.Content != "Hello World" {
|
||||
t.Fatalf("Content = %q, want %q", resp.Content, "Hello World")
|
||||
}
|
||||
if resp.ReasoningContent != "think " {
|
||||
t.Fatalf("ReasoningContent = %q, want %q", resp.ReasoningContent, "think ")
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].ID != "call_stream" {
|
||||
t.Fatalf("ToolCalls = %#v, want single call_stream", resp.ToolCalls)
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 3 {
|
||||
t.Fatalf("Usage = %#v, expected total tokens = 3", resp.Usage)
|
||||
}
|
||||
if len(updates) < 2 || updates[len(updates)-1] != "Hello World" {
|
||||
t.Fatalf("stream updates = %#v, expected final accumulated text", updates)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatal("response writer is not flushable")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprint(w, "data: \n\n")
|
||||
flusher.Flush()
|
||||
|
||||
chunk := map[string]any{
|
||||
"candidates": []any{map[string]any{
|
||||
"content": map[string]any{
|
||||
"parts": []any{map[string]any{"text": "ok"}},
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
}},
|
||||
}
|
||||
raw, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal chunk: %v", err)
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "data: %s\n\n", raw)
|
||||
flusher.Flush()
|
||||
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
|
||||
resp, err := provider.ChatStream(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStream() error = %v", err)
|
||||
}
|
||||
if resp.Content != "ok" {
|
||||
t.Fatalf("Content = %q, want %q", resp.Content, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatal("response writer is not flushable")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprint(w, "data: {invalid-json}\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
|
||||
_, err := provider.ChatStream(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("ChatStream() expected error for invalid SSE data frame")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid gemini stream chunk") {
|
||||
t.Fatalf("error = %v, want contains %q", err, "invalid gemini stream chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_UsesCamelCaseThoughtSignatureOnly(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"q": "hello"},
|
||||
Function: &FunctionCall{
|
||||
Name: "search",
|
||||
Arguments: `{"q":"hello"}`,
|
||||
ThoughtSignature: "sig-1",
|
||||
},
|
||||
}},
|
||||
}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
)
|
||||
|
||||
raw, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal request body: %v", err)
|
||||
}
|
||||
jsonBody := string(raw)
|
||||
|
||||
if !strings.Contains(jsonBody, `"thoughtSignature":"sig-1"`) {
|
||||
t.Fatalf("request body = %s, expected camelCase thoughtSignature", jsonBody)
|
||||
}
|
||||
if strings.Contains(jsonBody, `"thought_signature"`) {
|
||||
t.Fatalf("request body = %s, unexpected snake_case thought_signature", jsonBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatStreamCoalescesToolCallWithoutWireID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatal("response writer is not flushable")
|
||||
}
|
||||
|
||||
chunks := []map[string]any{
|
||||
{
|
||||
"candidates": []any{map[string]any{
|
||||
"content": map[string]any{
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": "search",
|
||||
"args": map[string]any{"q": "first"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
"candidates": []any{map[string]any{
|
||||
"content": map[string]any{
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": "search",
|
||||
"args": map[string]any{"q": "second"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
raw, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal chunk: %v", err)
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", raw); err != nil {
|
||||
t.Fatalf("write chunk: %v", err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewGeminiProvider("test-key", server.URL, "", "", 0, nil, nil)
|
||||
resp, err := provider.ChatStream(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStream() error = %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
tc := resp.ToolCalls[0]
|
||||
if tc.ID != "search#1" {
|
||||
t.Fatalf("ToolCall ID = %q, want %q", tc.ID, "search#1")
|
||||
}
|
||||
if tc.Name != "search" {
|
||||
t.Fatalf("ToolCall Name = %q, want %q", tc.Name, "search")
|
||||
}
|
||||
if argQ, ok := tc.Arguments["q"].(string); !ok || argQ != "second" {
|
||||
t.Fatalf("ToolCall Arguments = %#v, want q=second", tc.Arguments)
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Fatalf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBodyIncludesMediaAndThinkingConfig(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{
|
||||
Role: "user",
|
||||
Content: "analyze attachments",
|
||||
Media: []string{
|
||||
"data:application/pdf;base64,UEZERGF0YQ==",
|
||||
"data:image/png;base64,aW1hZ2VEYXRh",
|
||||
},
|
||||
}},
|
||||
nil,
|
||||
"gemini-3-flash-preview",
|
||||
map[string]any{
|
||||
"thinking_level": "low",
|
||||
"max_tokens": 128,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
)
|
||||
|
||||
contents, ok := body["contents"].([]geminiContent)
|
||||
if !ok || len(contents) != 1 {
|
||||
t.Fatalf("contents = %#v, want one gemini content", body["contents"])
|
||||
}
|
||||
parts := contents[0].Parts
|
||||
mimeSet := map[string]bool{}
|
||||
for _, part := range parts {
|
||||
if part.InlineData != nil {
|
||||
mimeSet[part.InlineData.MIMEType] = true
|
||||
}
|
||||
}
|
||||
if !mimeSet["application/pdf"] {
|
||||
t.Fatalf("inline media missing application/pdf: %#v", parts)
|
||||
}
|
||||
if !mimeSet["image/png"] {
|
||||
t.Fatalf("inline media missing image/png: %#v", parts)
|
||||
}
|
||||
|
||||
generationConfig, ok := body["generationConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
|
||||
}
|
||||
if got := generationConfig["maxOutputTokens"]; got != 128 {
|
||||
t.Fatalf("maxOutputTokens = %#v, want 128", got)
|
||||
}
|
||||
if got := generationConfig["temperature"]; got != 0.2 {
|
||||
t.Fatalf("temperature = %#v, want 0.2", got)
|
||||
}
|
||||
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
|
||||
}
|
||||
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || !includeThoughts {
|
||||
t.Fatalf("includeThoughts = %#v, want true", thinkingConfig["includeThoughts"])
|
||||
}
|
||||
if got := thinkingConfig["thinkingLevel"]; got != "low" {
|
||||
t.Fatalf("thinkingLevel = %#v, want %q", got, "low")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_UsesThinkingBudgetForGemini25(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
map[string]any{"thinking_level": "medium"},
|
||||
)
|
||||
|
||||
generationConfig, ok := body["generationConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
|
||||
}
|
||||
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
|
||||
}
|
||||
if got := thinkingConfig["thinkingBudget"]; got != 4096 {
|
||||
t.Fatalf("thinkingBudget = %#v, want 4096", got)
|
||||
}
|
||||
if _, hasLevel := thinkingConfig["thinkingLevel"]; hasLevel {
|
||||
t.Fatalf("thinkingLevel should not be set for Gemini 2.5: %#v", thinkingConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_OmitsThinkingConfigForGemini20(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.0-flash-exp",
|
||||
map[string]any{"thinking_level": "high"},
|
||||
)
|
||||
|
||||
if _, ok := body["generationConfig"]; ok {
|
||||
t.Fatalf("generationConfig should be omitted for Gemini 2.0 when only thinking_level is set: %#v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
)
|
||||
|
||||
generationConfig, ok := body["generationConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
|
||||
}
|
||||
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
|
||||
}
|
||||
if got := thinkingConfig["thinkingBudget"]; got != 0 {
|
||||
t.Fatalf("thinkingBudget = %#v, want 0 for default/off", got)
|
||||
}
|
||||
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
|
||||
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini3(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-3-flash-preview",
|
||||
nil,
|
||||
)
|
||||
|
||||
generationConfig, ok := body["generationConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
|
||||
}
|
||||
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
|
||||
}
|
||||
if got := thinkingConfig["thinkingLevel"]; got != "minimal" {
|
||||
t.Fatalf("thinkingLevel = %#v, want minimal for default/off", got)
|
||||
}
|
||||
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
|
||||
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini25Pro(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-pro",
|
||||
nil,
|
||||
)
|
||||
|
||||
generationConfig, ok := body["generationConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
|
||||
}
|
||||
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
|
||||
}
|
||||
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
|
||||
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
|
||||
}
|
||||
if _, hasBudget := thinkingConfig["thinkingBudget"]; hasBudget {
|
||||
t.Fatalf("thinkingBudget should be omitted for Gemini 2.5 Pro default/off: %#v", thinkingConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_DefaultsThinkingOffForGemini31Pro(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-3.1-pro",
|
||||
nil,
|
||||
)
|
||||
|
||||
generationConfig, ok := body["generationConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("generationConfig = %#v, want map", body["generationConfig"])
|
||||
}
|
||||
thinkingConfig, ok := generationConfig["thinkingConfig"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("thinkingConfig = %#v, want map", generationConfig["thinkingConfig"])
|
||||
}
|
||||
if includeThoughts, ok := thinkingConfig["includeThoughts"].(bool); !ok || includeThoughts {
|
||||
t.Fatalf("includeThoughts = %#v, want false for default/off", thinkingConfig["includeThoughts"])
|
||||
}
|
||||
if _, hasLevel := thinkingConfig["thinkingLevel"]; hasLevel {
|
||||
t.Fatalf("thinkingLevel should be omitted for Gemini 3.1 Pro default/off: %#v", thinkingConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_PreservesMultipleSystemMessages(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "system", Content: "Be concise."},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
nil,
|
||||
"gemini-3-flash-preview",
|
||||
nil,
|
||||
)
|
||||
|
||||
systemInstruction, ok := body["systemInstruction"].(*geminiContent)
|
||||
if !ok || systemInstruction == nil {
|
||||
t.Fatalf("systemInstruction = %#v, want *geminiContent", body["systemInstruction"])
|
||||
}
|
||||
if len(systemInstruction.Parts) != 2 {
|
||||
t.Fatalf("systemInstruction.Parts len = %d, want 2", len(systemInstruction.Parts))
|
||||
}
|
||||
if systemInstruction.Parts[0].Text != "You are helpful." || systemInstruction.Parts[1].Text != "Be concise." {
|
||||
t.Fatalf("systemInstruction.Parts = %#v, want ordered system prompts", systemInstruction.Parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_PreservesToolResponseMedia(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "load_image",
|
||||
Arguments: map[string]any{"path": "demo.png"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
Content: "tool result",
|
||||
Media: []string{
|
||||
"data:image/png;base64,aW1hZ2VEYXRh",
|
||||
"data:application/pdf;base64,UEZERGF0YQ==",
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
"gemini-3-flash-preview",
|
||||
nil,
|
||||
)
|
||||
|
||||
contents, ok := body["contents"].([]geminiContent)
|
||||
if !ok || len(contents) != 2 {
|
||||
t.Fatalf("contents = %#v, want two content entries", body["contents"])
|
||||
}
|
||||
parts := contents[1].Parts
|
||||
if len(parts) != 1 || parts[0].FunctionResponse == nil {
|
||||
t.Fatalf("tool response part = %#v, want functionResponse", parts)
|
||||
}
|
||||
response := parts[0].FunctionResponse
|
||||
if response.Name != "load_image" {
|
||||
t.Fatalf("functionResponse.Name = %q, want %q", response.Name, "load_image")
|
||||
}
|
||||
if response.Response["result"] != "tool result" {
|
||||
t.Fatalf("functionResponse.Response = %#v, want result=tool result", response.Response)
|
||||
}
|
||||
if len(response.Parts) != 2 {
|
||||
t.Fatalf("functionResponse.Parts len = %d, want 2", len(response.Parts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatAllowsCustomAuthHeaderWithoutAPIKey(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
|
||||
t.Fatalf("Authorization = %q, want %q", got, "Bearer test-token")
|
||||
}
|
||||
if got := r.Header.Get("X-Goog-Api-Key"); got != "" {
|
||||
t.Fatalf("X-Goog-Api-Key = %q, want empty", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"candidates": []any{
|
||||
map[string]any{
|
||||
"content": map[string]any{
|
||||
"parts": []any{map[string]any{"text": "ok"}},
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewGeminiProvider(
|
||||
"",
|
||||
server.URL,
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
nil,
|
||||
map[string]string{"Authorization": "Bearer test-token"},
|
||||
)
|
||||
|
||||
resp, err := provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if resp.Content != "ok" {
|
||||
t.Fatalf("Content = %q, want %q", resp.Content, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatAllowsMissingAPIKeyForCustomAPIBase(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("X-Goog-Api-Key"); got != "" {
|
||||
t.Fatalf("X-Goog-Api-Key = %q, want empty", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"candidates": []any{
|
||||
map[string]any{
|
||||
"content": map[string]any{"parts": []any{map[string]any{"text": "ok"}}},
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewGeminiProvider("", server.URL, "", "", 0, nil, nil)
|
||||
resp, err := provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
nil,
|
||||
"gemini-2.5-flash",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if resp.Content != "ok" {
|
||||
t.Fatalf("Content = %q, want %q", resp.Content, "ok")
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -181,9 +182,7 @@ func (p *Provider) buildRequestBody(
|
||||
|
||||
// Merge extra body fields configured per-provider/model.
|
||||
// These are injected last so they take precedence over defaults.
|
||||
for k, v := range p.extraBody {
|
||||
requestBody[k] = v
|
||||
}
|
||||
maps.Copy(requestBody, p.extraBody)
|
||||
|
||||
return requestBody
|
||||
}
|
||||
|
||||
@@ -281,6 +281,12 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
|
||||
}
|
||||
|
||||
case "assistant":
|
||||
// Reasoning-only assistant messages are transient display artifacts and
|
||||
// should not be restored from session history.
|
||||
if assistantMessageTransientThought(msg) {
|
||||
continue
|
||||
}
|
||||
|
||||
toolSummaryMessages := visibleAssistantToolSummaryMessages(msg.ToolCalls, toolFeedbackMaxArgsLength)
|
||||
if len(toolSummaryMessages) > 0 {
|
||||
transcript = append(transcript, toolSummaryMessages...)
|
||||
@@ -309,6 +315,13 @@ func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLen
|
||||
return transcript
|
||||
}
|
||||
|
||||
func assistantMessageTransientThought(msg providers.Message) bool {
|
||||
return strings.TrimSpace(msg.Content) == "" &&
|
||||
strings.TrimSpace(msg.ReasoningContent) != "" &&
|
||||
len(msg.ToolCalls) == 0 &&
|
||||
len(msg.Media) == 0
|
||||
}
|
||||
|
||||
func assistantMessageInternalOnly(msg providers.Message) bool {
|
||||
return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText
|
||||
}
|
||||
|
||||
@@ -218,6 +218,59 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-transient-thought"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", ReasoningContent: "internal chain of thought"},
|
||||
{Role: "assistant", Content: "final visible answer"},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-transient-thought", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "hello" {
|
||||
t.Fatalf("first message = %#v, want user/hello", resp.Messages[0])
|
||||
}
|
||||
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "final visible answer" {
|
||||
t.Fatalf("second message = %#v, want assistant/final visible answer", resp.Messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { IconCheck, IconCopy } from "@tabler/icons-react"
|
||||
import { IconBrain, IconCheck, IconCopy } from "@tabler/icons-react"
|
||||
import { useState } from "react"
|
||||
import { useTranslation } from "react-i18next"
|
||||
import ReactMarkdown from "react-markdown"
|
||||
import rehypeRaw from "rehype-raw"
|
||||
import rehypeSanitize from "rehype-sanitize"
|
||||
@@ -7,16 +8,20 @@ import remarkGfm from "remark-gfm"
|
||||
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { formatMessageTime } from "@/hooks/use-pico-chat"
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
interface AssistantMessageProps {
|
||||
content: string
|
||||
isThought?: boolean
|
||||
timestamp?: string | number
|
||||
}
|
||||
|
||||
export function AssistantMessage({
|
||||
content,
|
||||
isThought = false,
|
||||
timestamp = "",
|
||||
}: AssistantMessageProps) {
|
||||
const { t } = useTranslation()
|
||||
const [isCopied, setIsCopied] = useState(false)
|
||||
const formattedTimestamp =
|
||||
timestamp !== "" ? formatMessageTime(timestamp) : ""
|
||||
@@ -33,6 +38,12 @@ export function AssistantMessage({
|
||||
<div className="text-muted-foreground flex items-center justify-between gap-2 px-1 text-xs opacity-70">
|
||||
<div className="flex items-center gap-2">
|
||||
<span>PicoClaw</span>
|
||||
{isThought && (
|
||||
<span className="inline-flex items-center gap-1 rounded-full border border-amber-300/80 bg-amber-100/80 px-2 py-0.5 text-[11px] font-medium text-amber-800 dark:border-amber-500/40 dark:bg-amber-500/15 dark:text-amber-200">
|
||||
<IconBrain className="size-3" />
|
||||
<span>{t("chat.reasoningLabel")}</span>
|
||||
</span>
|
||||
)}
|
||||
{formattedTimestamp && (
|
||||
<>
|
||||
<span className="opacity-50">•</span>
|
||||
@@ -42,8 +53,22 @@ export function AssistantMessage({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-card text-card-foreground relative overflow-hidden rounded-xl border">
|
||||
<div className="prose dark:prose-invert prose-p:my-2 prose-pre:my-2 prose-pre:overflow-x-auto prose-pre:rounded-lg prose-pre:border prose-pre:bg-zinc-950 prose-pre:p-3 max-w-none p-4 text-[15px] leading-relaxed [overflow-wrap:anywhere] break-words">
|
||||
<div
|
||||
className={cn(
|
||||
"relative overflow-hidden rounded-xl border",
|
||||
isThought
|
||||
? "border-amber-200/90 bg-amber-50/70 text-amber-950 dark:border-amber-500/35 dark:bg-amber-500/10 dark:text-amber-100"
|
||||
: "bg-card text-card-foreground",
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"prose dark:prose-invert prose-pre:my-2 prose-pre:overflow-x-auto prose-pre:rounded-lg prose-pre:border prose-pre:bg-zinc-950 prose-pre:p-3 max-w-none [overflow-wrap:anywhere] break-words",
|
||||
isThought
|
||||
? "prose-p:my-1.5 p-3 text-[13px] leading-relaxed opacity-90"
|
||||
: "prose-p:my-2 p-4 text-[15px] leading-relaxed",
|
||||
)}
|
||||
>
|
||||
<ReactMarkdown
|
||||
remarkPlugins={[remarkGfm]}
|
||||
rehypePlugins={[rehypeRaw, rehypeSanitize]}
|
||||
@@ -54,7 +79,12 @@ export function AssistantMessage({
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="bg-background/50 hover:bg-background/80 absolute top-2 right-2 h-7 w-7 opacity-0 transition-opacity group-hover:opacity-100"
|
||||
className={cn(
|
||||
"absolute top-2 right-2 h-7 w-7 opacity-0 transition-opacity group-hover:opacity-100",
|
||||
isThought
|
||||
? "bg-amber-100/70 hover:bg-amber-200/80 dark:bg-amber-500/20 dark:hover:bg-amber-400/30"
|
||||
: "bg-background/50 hover:bg-background/80",
|
||||
)}
|
||||
onClick={handleCopy}
|
||||
>
|
||||
{isCopied ? (
|
||||
|
||||
@@ -247,6 +247,7 @@ export function ChatPage() {
|
||||
{msg.role === "assistant" ? (
|
||||
<AssistantMessage
|
||||
content={msg.content}
|
||||
isThought={msg.kind === "thought"}
|
||||
timestamp={msg.timestamp}
|
||||
/>
|
||||
) : (
|
||||
|
||||
@@ -24,6 +24,7 @@ export async function loadSessionMessages(
|
||||
id: `hist-${index}-${Date.now()}`,
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
kind: message.role === "assistant" ? "normal" : undefined,
|
||||
attachments: toChatAttachments(message.media),
|
||||
timestamp: fallbackTime,
|
||||
}))
|
||||
@@ -50,7 +51,7 @@ function messageSignature(message: ChatMessage): string {
|
||||
|
||||
return `${message.role}\u0000${message.content}\u0000${normalizeMessageTimestamp(
|
||||
message.timestamp,
|
||||
)}\u0000${attachmentSignature}`
|
||||
)}\u0000${message.kind ?? ""}\u0000${attachmentSignature}`
|
||||
}
|
||||
|
||||
function comparableTimestamp(timestamp: number | string): number {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import { toast } from "sonner"
|
||||
|
||||
import { normalizeUnixTimestamp } from "@/features/chat/state"
|
||||
import { updateChatStore } from "@/store/chat"
|
||||
import {
|
||||
type AssistantMessageKind,
|
||||
updateChatStore,
|
||||
} from "@/store/chat"
|
||||
|
||||
export interface PicoMessage {
|
||||
type: string
|
||||
@@ -11,6 +14,16 @@ export interface PicoMessage {
|
||||
payload?: Record<string, unknown>
|
||||
}
|
||||
|
||||
function parseAssistantMessageKind(
|
||||
payload: Record<string, unknown>,
|
||||
): AssistantMessageKind {
|
||||
return payload.thought === true ? "thought" : "normal"
|
||||
}
|
||||
|
||||
function hasAssistantKindPayload(payload: Record<string, unknown>): boolean {
|
||||
return typeof payload.thought === "boolean"
|
||||
}
|
||||
|
||||
export function handlePicoMessage(
|
||||
message: PicoMessage,
|
||||
expectedSessionId: string,
|
||||
@@ -25,6 +38,7 @@ export function handlePicoMessage(
|
||||
case "message.create": {
|
||||
const content = (payload.content as string) || ""
|
||||
const messageId = (payload.message_id as string) || `pico-${Date.now()}`
|
||||
const kind = parseAssistantMessageKind(payload)
|
||||
const timestamp =
|
||||
message.timestamp !== undefined &&
|
||||
Number.isFinite(Number(message.timestamp))
|
||||
@@ -38,6 +52,7 @@ export function handlePicoMessage(
|
||||
id: messageId,
|
||||
role: "assistant",
|
||||
content,
|
||||
kind,
|
||||
timestamp,
|
||||
},
|
||||
],
|
||||
@@ -49,13 +64,21 @@ export function handlePicoMessage(
|
||||
case "message.update": {
|
||||
const content = (payload.content as string) || ""
|
||||
const messageId = payload.message_id as string
|
||||
const hasKind = hasAssistantKindPayload(payload)
|
||||
const kind = parseAssistantMessageKind(payload)
|
||||
if (!messageId) {
|
||||
break
|
||||
}
|
||||
|
||||
updateChatStore((prev) => ({
|
||||
messages: prev.messages.map((msg) =>
|
||||
msg.id === messageId ? { ...msg, content } : msg,
|
||||
msg.id === messageId
|
||||
? {
|
||||
...msg,
|
||||
content,
|
||||
...(hasKind ? { kind } : {}),
|
||||
}
|
||||
: msg,
|
||||
),
|
||||
}))
|
||||
break
|
||||
|
||||
@@ -47,6 +47,7 @@
|
||||
"step3": "Preparing response...",
|
||||
"step4": "Almost there..."
|
||||
},
|
||||
"reasoningLabel": "Reasoning",
|
||||
"history": "History",
|
||||
"noHistory": "No chat history yet",
|
||||
"historyLoadFailed": "Failed to load chat history",
|
||||
|
||||
@@ -47,6 +47,7 @@
|
||||
"step3": "准备回复...",
|
||||
"step4": "马上就好..."
|
||||
},
|
||||
"reasoningLabel": "思考",
|
||||
"history": "历史记录",
|
||||
"noHistory": "暂无对话历史",
|
||||
"historyLoadFailed": "加载历史记录失败",
|
||||
|
||||
@@ -11,11 +11,14 @@ export interface ChatAttachment {
|
||||
filename?: string
|
||||
}
|
||||
|
||||
export type AssistantMessageKind = "normal" | "thought"
|
||||
|
||||
export interface ChatMessage {
|
||||
id: string
|
||||
role: "user" | "assistant"
|
||||
content: string
|
||||
timestamp: number | string
|
||||
kind?: AssistantMessageKind
|
||||
attachments?: ChatAttachment[]
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user