From b7db0595441d1f725cd4ace994a9936e08557337 Mon Sep 17 00:00:00 2001 From: LC Date: Wed, 20 May 2026 13:42:21 +0800 Subject: [PATCH] feat(chat,seahorse): persist and display model_name across history (#2897) * feat(chat,seahorse): persist and display model_name across history * test(seahorse): fix lint regressions in repair coverage * fix(pico): preserve model_name in live updates * fix(pico): preserve model_name through live stream wrappers --- pkg/agent/agent.go | 6 + pkg/agent/agent_outbound.go | 43 +++- pkg/agent/agent_test.go | 2 +- pkg/agent/agent_utils.go | 52 ++++- pkg/agent/context_seahorse.go | 2 + pkg/agent/context_seahorse_test.go | 8 + pkg/agent/model_resolution.go | 18 +- pkg/agent/model_resolution_test.go | 49 +++++ pkg/agent/pipeline_execute.go | 12 +- pkg/agent/pipeline_finalize.go | 25 +-- pkg/agent/pipeline_llm.go | 14 +- pkg/agent/pipeline_setup.go | 6 + pkg/agent/pipeline_streaming.go | 17 +- pkg/agent/pipeline_streaming_test.go | 3 + pkg/agent/turn_coord_test.go | 180 +++++++++++++++ pkg/agent/turn_state.go | 2 + pkg/channels/interfaces.go | 11 + pkg/channels/manager.go | 65 +++++- pkg/channels/manager_test.go | 102 +++++++++ pkg/channels/pico/pico.go | 100 ++++++++- pkg/channels/pico/pico_test.go | 94 +++++++- pkg/channels/pico/protocol.go | 1 + pkg/memory/jsonl_test.go | 26 +++ pkg/providers/fallback.go | 17 +- pkg/providers/protocoltypes/types.go | 1 + pkg/seahorse/schema.go | 19 ++ pkg/seahorse/schema_test.go | 31 +++ pkg/seahorse/short_engine.go | 117 ++++++++-- pkg/seahorse/short_engine_test.go | 206 +++++++++++++++--- pkg/seahorse/store.go | 67 ++++-- pkg/seahorse/store_test.go | 14 ++ pkg/seahorse/types.go | 2 + pkg/session/jsonl_backend_test.go | 19 ++ web/backend/api/session.go | 22 +- web/backend/api/session_test.go | 19 +- web/frontend/src/api/sessions.ts | 1 + .../src/components/chat/assistant-message.tsx | 14 +- .../src/components/chat/chat-page.tsx | 1 + web/frontend/src/features/chat/history.ts | 3 +- web/frontend/src/features/chat/protocol.ts | 13 ++ web/frontend/src/store/chat.ts | 1 + 41 files changed, 1266 insertions(+), 139 deletions(-) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 49cd5e767..0392829f9 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -600,6 +600,12 @@ func (al *AgentLoop) runAgentLoop( Content: result.finalContent, ContextUsage: computeContextUsage(agent, opts.Dispatch.SessionKey), } + if modelName := strings.TrimSpace(result.modelName); modelName != "" { + if msg.Context.Raw == nil { + msg.Context.Raw = make(map[string]string, 1) + } + msg.Context.Raw["model_name"] = modelName + } markFinalOutbound(&msg) al.bus.PublishOutbound(ctx, msg) } diff --git a/pkg/agent/agent_outbound.go b/pkg/agent/agent_outbound.go index cae6dcb48..f5de6fa41 100644 --- a/pkg/agent/agent_outbound.go +++ b/pkg/agent/agent_outbound.go @@ -102,7 +102,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string return "" } -func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, chatID, sessionKey string) { +func (al *AgentLoop) publishPicoReasoning( + ctx context.Context, + reasoningContent, chatID, sessionKey, modelName string, +) { if reasoningContent == "" || chatID == "" { return } @@ -114,13 +117,16 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second) defer pubCancel() + raw := map[string]string{metadataKeyMessageKind: messageKindThought} + if trimmedModelName := strings.TrimSpace(modelName); trimmedModelName != "" { + raw["model_name"] = trimmedModelName + } + if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ Context: bus.InboundContext{ Channel: "pico", ChatID: chatID, - Raw: map[string]string{ - metadataKeyMessageKind: messageKindThought, - }, + Raw: raw, }, SessionKey: sessionKey, Content: reasoningContent, @@ -143,6 +149,7 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, func (al *AgentLoop) publishPicoToolCallInterim( ctx context.Context, ts *turnState, + modelName string, reasoningContent string, content string, toolCalls []providers.ToolCall, @@ -155,7 +162,14 @@ func (al *AgentLoop) publishPicoToolCallInterim( pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second) err := al.bus.PublishOutbound( pubCtx, - outboundMessageForTurnWithKind(ts, reasoningContent, messageKindThought), + outboundMessageForTurnWithOptions( + ts, + reasoningContent, + outboundTurnMessageOptions{ + kind: messageKindThought, + modelName: modelName, + }, + ), ) pubCancel() if err != nil && !errors.Is(err, context.DeadlineExceeded) && @@ -182,7 +196,12 @@ func (al *AgentLoop) publishPicoToolCallInterim( if strings.TrimSpace(content) != "" && !duplicateToolCallContent { pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second) - err := al.bus.PublishOutbound(pubCtx, outboundMessageForTurn(ts, content)) + err := al.bus.PublishOutbound( + pubCtx, + outboundMessageForTurnWithOptions(ts, content, outboundTurnMessageOptions{ + modelName: modelName, + }), + ) pubCancel() if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) && @@ -209,11 +228,13 @@ func (al *AgentLoop) publishPicoToolCallInterim( return } - msg := outboundMessageForTurnWithKind(ts, "", messageKindToolCalls) - if msg.Context.Raw == nil { - msg.Context.Raw = map[string]string{} - } - msg.Context.Raw[metadataKeyToolCalls] = string(rawToolCalls) + msg := outboundMessageForTurnWithOptions(ts, "", outboundTurnMessageOptions{ + kind: messageKindToolCalls, + modelName: modelName, + raw: map[string]string{ + metadataKeyToolCalls: string(rawToolCalls), + }, + }) pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second) err = al.bus.PublishOutbound(pubCtx, msg) diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index bfdb0d2a4..aaf3d1a88 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -312,7 +312,7 @@ func TestPublishPicoReasoningIncludesSessionKey(t *testing.T) { defer cleanup() _ = provider - al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1") + al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1", "") select { case outbound := <-msgBus.OutboundChan(): diff --git a/pkg/agent/agent_utils.go b/pkg/agent/agent_utils.go index b155740a9..432a9f24c 100644 --- a/pkg/agent/agent_utils.go +++ b/pkg/agent/agent_utils.go @@ -96,15 +96,46 @@ func markFinalOutbound(msg *bus.OutboundMessage) { msg.Context.Raw[metadataKeyOutboundKind] = outboundKindFinal } -func outboundMessageForTurnWithKind(ts *turnState, content, kind string) bus.OutboundMessage { +type outboundTurnMessageOptions struct { + kind string + modelName string + raw map[string]string +} + +func outboundMessageForTurnWithOptions( + ts *turnState, + content string, + opts outboundTurnMessageOptions, +) bus.OutboundMessage { msg := outboundMessageForTurn(ts, content) - if strings.TrimSpace(kind) == "" { + trimmedKind := strings.TrimSpace(opts.kind) + trimmedModelName := strings.TrimSpace(opts.modelName) + rawCount := len(opts.raw) + if trimmedKind != "" { + rawCount++ + } + if trimmedModelName != "" { + rawCount++ + } + if rawCount == 0 { return msg } + if msg.Context.Raw == nil { - msg.Context.Raw = make(map[string]string, 1) + msg.Context.Raw = make(map[string]string, rawCount) + } + if trimmedKind != "" { + msg.Context.Raw[metadataKeyMessageKind] = trimmedKind + } + if trimmedModelName != "" { + msg.Context.Raw["model_name"] = trimmedModelName + } + for key, value := range opts.raw { + if strings.TrimSpace(key) == "" { + continue + } + msg.Context.Raw[key] = value } - msg.Context.Raw[metadataKeyMessageKind] = kind return msg } @@ -521,8 +552,9 @@ func hasMediaRefs(messages []providers.Message) bool { func sideQuestionModelName(agent *AgentInstance, usedLight bool) string { if usedLight && len(agent.LightCandidates) > 0 { - // Use the first light candidate's model - return agent.LightCandidates[0].Model + if name := resolvedCandidateModelName(agent.LightCandidates, ""); name != "" { + return name + } } return agent.Model } @@ -538,6 +570,14 @@ func modelNameFromIdentityKey(identityKey string) string { return identityKey } +func modelAliasFromCandidateIdentityKey(identityKey string) string { + const prefix = "model_name:" + if !strings.HasPrefix(identityKey, prefix) { + return "" + } + return strings.TrimSpace(strings.TrimPrefix(identityKey, prefix)) +} + func closeProviderIfStateful(provider providers.LLMProvider) { if stateful, ok := provider.(providers.StatefulProvider); ok { stateful.Close() diff --git a/pkg/agent/context_seahorse.go b/pkg/agent/context_seahorse.go index c6e5b30ac..2a10d2457 100644 --- a/pkg/agent/context_seahorse.go +++ b/pkg/agent/context_seahorse.go @@ -197,6 +197,7 @@ func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message { result := seahorse.Message{ Role: msg.Role, Content: msg.Content, + ModelName: msg.ModelName, ReasoningContent: msg.ReasoningContent, TokenCount: tokenizer.EstimateMessageTokens(msg), } @@ -243,6 +244,7 @@ func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes pm := protocoltypes.Message{ Role: msg.Role, Content: msg.Content, + ModelName: msg.ModelName, ReasoningContent: msg.ReasoningContent, } diff --git a/pkg/agent/context_seahorse_test.go b/pkg/agent/context_seahorse_test.go index e405ef944..2a9de3263 100644 --- a/pkg/agent/context_seahorse_test.go +++ b/pkg/agent/context_seahorse_test.go @@ -174,6 +174,7 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) { msg := protocoltypes.Message{ Role: "assistant", Content: "response text", + ModelName: "gpt-5.4-mini", ReasoningContent: "I thought about this carefully", } @@ -181,6 +182,9 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) { if result.ReasoningContent != "I thought about this carefully" { t.Errorf("ReasoningContent = %q, want 'I thought about this carefully'", result.ReasoningContent) } + if result.ModelName != "gpt-5.4-mini" { + t.Errorf("ModelName = %q, want %q", result.ModelName, "gpt-5.4-mini") + } } func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) { @@ -189,6 +193,7 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) { { Role: "assistant", Content: "response", + ModelName: "gpt-5.4", ReasoningContent: "thinking process", }, }, @@ -201,6 +206,9 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) { if messages[0].ReasoningContent != "thinking process" { t.Errorf("ReasoningContent = %q, want 'thinking process'", messages[0].ReasoningContent) } + if messages[0].ModelName != "gpt-5.4" { + t.Errorf("ModelName = %q, want %q", messages[0].ModelName, "gpt-5.4") + } } func TestSeahorseToProviderMessages(t *testing.T) { diff --git a/pkg/agent/model_resolution.go b/pkg/agent/model_resolution.go index 724278e66..487a30722 100644 --- a/pkg/agent/model_resolution.go +++ b/pkg/agent/model_resolution.go @@ -75,6 +75,7 @@ func candidateFromModelConfig( return providers.FallbackCandidate{ Provider: protocol, Model: modelID, + DisplayName: strings.TrimSpace(mc.ModelName), RPM: mc.RPM, IdentityKey: modelConfigIdentityKey(mc), }, true @@ -147,8 +148,9 @@ func resolveModelCandidate( } return providers.FallbackCandidate{ - Provider: ref.Provider, - Model: ref.Model, + Provider: ref.Provider, + Model: ref.Model, + DisplayName: raw, }, true } @@ -197,6 +199,18 @@ func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallbac return fallback } +func resolvedCandidateModelName(candidates []providers.FallbackCandidate, fallback string) string { + if len(candidates) > 0 { + if name := modelAliasFromCandidateIdentityKey(candidates[0].IdentityKey); strings.TrimSpace(name) != "" { + return name + } + if displayName := strings.TrimSpace(candidates[0].DisplayName); displayName != "" { + return displayName + } + } + return strings.TrimSpace(fallback) +} + func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) { if cfg == nil { return nil, fmt.Errorf("config is nil") diff --git a/pkg/agent/model_resolution_test.go b/pkg/agent/model_resolution_test.go index 270242eab..ea23e5dac 100644 --- a/pkg/agent/model_resolution_test.go +++ b/pkg/agent/model_resolution_test.go @@ -7,6 +7,55 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" ) +func TestModelNameFromIdentityKey_LegacyProviderModel(t *testing.T) { + if got := modelNameFromIdentityKey("openai/gpt-5.4"); got != "gpt-5.4" { + t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "gpt-5.4") + } +} + +func TestModelNameFromIdentityKey_PreservesNonLegacyIdentity(t *testing.T) { + if got := modelNameFromIdentityKey("model_name:primary"); got != "model_name:primary" { + t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "model_name:primary") + } +} + +func TestModelAliasFromCandidateIdentityKey(t *testing.T) { + if got := modelAliasFromCandidateIdentityKey("model_name:primary"); got != "primary" { + t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want %q", got, "primary") + } + if got := modelAliasFromCandidateIdentityKey("openai/gpt-5.4"); got != "" { + t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want empty", got) + } +} + +func TestResolvedCandidateModelName_PrefersIdentityAlias(t *testing.T) { + got := resolvedCandidateModelName([]providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"}, + }, "fallback-model") + if got != "primary" { + t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary") + } +} + +func TestResolvedCandidateModelName_DoesNotScanFallbackAliases(t *testing.T) { + got := resolvedCandidateModelName([]providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4"}, + {Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:fallback"}, + }, "primary-model") + if got != "primary-model" { + t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary-model") + } +} + +func TestResolvedCandidateModelName_UsesCandidateDisplayName(t *testing.T) { + got := resolvedCandidateModelName([]providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4", DisplayName: "gpt-5.4-display"}, + }, "fallback-model") + if got != "gpt-5.4-display" { + t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "gpt-5.4-display") + } +} + func TestResolveActiveModelConfig_PrefersCandidateIdentityKey(t *testing.T) { cfg := &config.Config{ ModelList: []*config.ModelConfig{ diff --git a/pkg/agent/pipeline_execute.go b/pkg/agent/pipeline_execute.go index 567e56d17..7281a0ead 100644 --- a/pkg/agent/pipeline_execute.go +++ b/pkg/agent/pipeline_execute.go @@ -180,7 +180,11 @@ toolLoop: toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen), ) fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second) - _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback)) + _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions( + ts, + feedbackMsg, + outboundTurnMessageOptions{kind: messageKindToolFeedback}, + )) fbCancel() } @@ -467,7 +471,11 @@ toolLoop: toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen), ) fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second) - _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback)) + _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions( + ts, + feedbackMsg, + outboundTurnMessageOptions{kind: messageKindToolFeedback}, + )) fbCancel() } diff --git a/pkg/agent/pipeline_finalize.go b/pkg/agent/pipeline_finalize.go index d15e312e5..1b674d0aa 100644 --- a/pkg/agent/pipeline_finalize.go +++ b/pkg/agent/pipeline_finalize.go @@ -33,6 +33,7 @@ func (p *Pipeline) Finalize( ts.setPhase(TurnPhaseCompleted) return turnResult{ finalContent: finalContent, + modelName: exec.llmModelName, status: turnStatus, followUps: append([]bus.InboundMessage(nil), ts.followUps...), }, nil @@ -44,6 +45,7 @@ func (p *Pipeline) Finalize( finalMsg := providers.Message{ Role: "assistant", Content: finalContent, + ModelName: exec.llmModelName, ReasoningContent: responseReasoningContent(exec.response), } ts.agent.Sessions.AddFullMessage(ts.sessionKey, finalMsg) @@ -80,24 +82,10 @@ func (p *Pipeline) Finalize( // so the final answer is still delivered outside normal SendResponse. if ((streamErr != nil && !isConfiguredStreamingVisibleError(streamErr)) || exec.streamingFallback) && !ts.opts.SendResponse && ts.opts.AllowInterimPicoPublish && finalContent != "" { - agentID, sessionKey, scope := outboundTurnMetadata( - ts.agent.ID, - ts.opts.Dispatch.SessionKey, - ts.opts.Dispatch.SessionScope, - ) - msg := bus.OutboundMessage{ - Context: outboundContextFromInbound( - ts.opts.Dispatch.InboundContext, - ts.opts.Dispatch.Channel(), - ts.opts.Dispatch.ChatID(), - ts.opts.Dispatch.ReplyToMessageID(), - ), - AgentID: agentID, - SessionKey: sessionKey, - Scope: scope, - Content: finalContent, - ContextUsage: contextUsage, - } + msg := outboundMessageForTurnWithOptions(ts, finalContent, outboundTurnMessageOptions{ + modelName: exec.llmModelName, + }) + msg.ContextUsage = contextUsage markFinalOutbound(&msg) _ = al.bus.PublishOutbound(turnCtx, msg) } @@ -112,6 +100,7 @@ func (p *Pipeline) Finalize( ts.setPhase(TurnPhaseCompleted) return turnResult{ finalContent: finalContent, + modelName: exec.llmModelName, status: turnStatus, followUps: append([]bus.InboundMessage(nil), ts.followUps...), }, nil diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index 2410b1b54..aaae765ef 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -200,6 +200,16 @@ func (p *Pipeline) CallLLM( map[string]any{"agent_id": ts.agent.ID, "iteration": iteration}, ) } + for _, candidate := range exec.activeCandidates { + if candidate.StableKey() != fbResult.IdentityKey { + continue + } + exec.llmModelName = resolvedCandidateModelName( + []providers.FallbackCandidate{candidate}, + exec.llmModelName, + ) + break + } return fbResult.Response, nil } return exec.activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, exec.llmModel, exec.llmOpts) @@ -477,7 +487,7 @@ func (p *Pipeline) CallLLM( // Publish pico thoughts before the turn context is canceled at return time. // The async variant can race with turn teardown and intermittently drop the // thought message in CI even though the LLM produced reasoning content. - al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey) + al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey, exec.llmModelName) } } else { go al.handleReasoning( @@ -564,6 +574,7 @@ func (p *Pipeline) CallLLM( assistantMsg := providers.Message{ Role: "assistant", Content: exec.response.Content, + ModelName: exec.llmModelName, ReasoningContent: reasoningContent, } for _, tc := range exec.normalizedToolCalls { @@ -607,6 +618,7 @@ func (p *Pipeline) CallLLM( al.publishPicoToolCallInterim( turnCtx, ts, + exec.llmModelName, reasoningContent, exec.response.Content, assistantMsg.ToolCalls, diff --git a/pkg/agent/pipeline_setup.go b/pkg/agent/pipeline_setup.go index ba6fa06fb..1fdd8f6c4 100644 --- a/pkg/agent/pipeline_setup.go +++ b/pkg/agent/pipeline_setup.go @@ -89,6 +89,11 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution if usedLight && ts.agent.LightProvider != nil { activeProvider = ts.agent.LightProvider } + activeModelName := strings.TrimSpace(ts.agent.Model) + if usedLight { + activeModelName = strings.TrimSpace(sideQuestionModelName(ts.agent, true)) + } + activeModelName = resolvedCandidateModelName(activeCandidates, activeModelName) exec := newTurnExecution( ts.agent, @@ -106,6 +111,7 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution activeModel, p.Cfg.Agents.Defaults.Provider, ) + exec.llmModelName = activeModelName exec.activeProvider = activeProvider exec.usedLight = usedLight diff --git a/pkg/agent/pipeline_streaming.go b/pkg/agent/pipeline_streaming.go index 5a9bbfd37..d0ef90524 100644 --- a/pkg/agent/pipeline_streaming.go +++ b/pkg/agent/pipeline_streaming.go @@ -50,9 +50,10 @@ func (p *Pipeline) tryConfiguredStreamingLLM( } publisher := &streamingChunkPublisher{ - streamer: streamer, - channel: ts.channel, - chatID: ts.chatID, + streamer: streamer, + channel: ts.channel, + chatID: ts.chatID, + modelName: exec.llmModelName, } logger.DebugCF("agent", "configured streaming enabled", map[string]any{ @@ -371,6 +372,7 @@ type streamingChunkPublisher struct { streamer bus.Streamer channel string chatID string + modelName string published bool reasoningPublished bool err error @@ -380,6 +382,9 @@ func (p *streamingChunkPublisher) Update(ctx context.Context, accumulated string if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" { return } + if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok { + setter.SetModelName(p.modelName) + } if err := p.streamer.Update(ctx, accumulated); err != nil { p.err = err logger.WarnCF("agent", "stream update failed", map[string]any{ @@ -396,6 +401,9 @@ func (p *streamingChunkPublisher) UpdateReasoning(ctx context.Context, accumulat if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" { return } + if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok { + setter.SetModelName(p.modelName) + } reasoningStreamer, ok := p.streamer.(bus.ReasoningStreamer) if !ok { return @@ -434,6 +442,9 @@ func (p *streamingChunkPublisher) Finalize(ctx context.Context, content string, if strings.TrimSpace(content) == "" && !p.published { return nil } + if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok { + setter.SetModelName(p.modelName) + } var err error if streamer, ok := p.streamer.(bus.ContextUsageStreamer); ok { err = streamer.FinalizeWithContext(ctx, content, contextUsage) diff --git a/pkg/agent/pipeline_streaming_test.go b/pkg/agent/pipeline_streaming_test.go index 98122f34a..a18e4c6f1 100644 --- a/pkg/agent/pipeline_streaming_test.go +++ b/pkg/agent/pipeline_streaming_test.go @@ -570,6 +570,9 @@ func TestConfiguredStreamingFinalFlushFailureBeforeVisibleOutputPublishesFallbac if outbound.Content != "stream response" { t.Fatalf("fallback outbound content = %q, want stream response", outbound.Content) } + if got := outbound.Context.Raw["model_name"]; got != "test-model" { + t.Fatalf("fallback outbound model_name = %q, want %q", got, "test-model") + } case <-time.After(time.Second): t.Fatal("expected fallback outbound after invisible final stream flush failure") } diff --git a/pkg/agent/turn_coord_test.go b/pkg/agent/turn_coord_test.go index c7cdd8a32..4a63e639b 100644 --- a/pkg/agent/turn_coord_test.go +++ b/pkg/agent/turn_coord_test.go @@ -10,6 +10,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/session" ) @@ -37,6 +38,39 @@ func (p *simpleConvProvider) GetDefaultModel() string { return "simple-model" } +type sequenceProvider struct { + responses []*providers.LLMResponse + errors []error + callCount int + mu sync.Mutex +} + +func (p *sequenceProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + + idx := p.callCount + p.callCount++ + + if idx < len(p.errors) && p.errors[idx] != nil { + return nil, p.errors[idx] + } + if idx < len(p.responses) && p.responses[idx] != nil { + return p.responses[idx], nil + } + return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil +} + +func (p *sequenceProvider) GetDefaultModel() string { + return "sequence-model" +} + type nativeSearchCaptureProvider struct { lastOpts map[string]any } @@ -271,6 +305,152 @@ func TestPipeline_CallLLM_SimpleResponse(t *testing.T) { } } +func TestPipeline_SetupTurn_ModelNameDoesNotUseFallbackAliasBeforeFallback(t *testing.T) { + al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{}) + defer cleanup() + + agent.Model = "primary-model" + agent.Candidates = []providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4"}, + {Provider: "anthropic", Model: "claude-sonnet", IdentityKey: "model_name:fallback-model"}, + } + + pipeline := NewPipeline(al) + ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{ + turnID: "turn-1", + context: newTurnContext(nil, nil, nil), + }) + + exec, err := pipeline.SetupTurn(context.Background(), ts) + if err != nil { + t.Fatalf("SetupTurn failed: %v", err) + } + if exec.llmModelName != "primary-model" { + t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "primary-model") + } +} + +func TestPipeline_CallLLM_UsesSuccessfulFallbackIdentityAlias(t *testing.T) { + provider := &sequenceProvider{ + errors: []error{ + errors.New("status: 429 - rate limit exceeded"), + nil, + }, + responses: []*providers.LLMResponse{ + nil, + {Content: "fallback answer", FinishReason: "stop"}, + }, + } + al, agent, cleanup := newTurnCoordTestLoop(t, provider) + defer cleanup() + + agent.Model = "primary-model" + agent.Candidates = []providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"}, + {Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:secondary"}, + } + al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil) + + pipeline := NewPipeline(al) + ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{ + turnID: "turn-1", + context: newTurnContext(nil, nil, nil), + }) + + exec, err := pipeline.SetupTurn(context.Background(), ts) + if err != nil { + t.Fatalf("SetupTurn failed: %v", err) + } + + ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1) + if err != nil { + t.Fatalf("CallLLM failed: %v", err) + } + if ctrl != ControlBreak { + t.Fatalf("expected ControlBreak, got %v", ctrl) + } + if exec.llmModelName != "secondary" { + t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "secondary") + } +} + +func TestPipeline_CallLLM_UsesSuccessfulFallbackDisplayNameWithoutAlias(t *testing.T) { + provider := &sequenceProvider{ + errors: []error{ + errors.New("status: 429 - rate limit exceeded"), + nil, + }, + responses: []*providers.LLMResponse{ + nil, + {Content: "fallback answer", FinishReason: "stop"}, + }, + } + al, agent, cleanup := newTurnCoordTestLoop(t, provider) + defer cleanup() + + agent.Model = "primary-model" + agent.Candidates = []providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"}, + {Provider: "anthropic", Model: "claude-sonnet", DisplayName: "anthropic/claude-sonnet"}, + } + al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil) + + pipeline := NewPipeline(al) + ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{ + turnID: "turn-1", + context: newTurnContext(nil, nil, nil), + }) + + exec, err := pipeline.SetupTurn(context.Background(), ts) + if err != nil { + t.Fatalf("SetupTurn failed: %v", err) + } + + ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1) + if err != nil { + t.Fatalf("CallLLM failed: %v", err) + } + if ctrl != ControlBreak { + t.Fatalf("expected ControlBreak, got %v", ctrl) + } + if exec.llmModelName != "anthropic/claude-sonnet" { + t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "anthropic/claude-sonnet") + } +} + +func TestPipeline_SetupTurn_UsesLightCandidateDisplayName(t *testing.T) { + al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{}) + defer cleanup() + + agent.Model = "primary-model" + agent.Candidates = []providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"}, + } + agent.LightCandidates = []providers.FallbackCandidate{ + {Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:light-model", DisplayName: "light-model"}, + } + agent.Router = routing.New(routing.RouterConfig{LightModel: "light-model", Threshold: 1}) + + pipeline := NewPipeline(al) + opts := makeTestProcessOpts("test-session") + opts.UserMessage = "" + ts := newTurnState(agent, opts, turnEventScope{ + turnID: "turn-1", + context: newTurnContext(nil, nil, nil), + }) + + exec, err := pipeline.SetupTurn(context.Background(), ts) + if err != nil { + t.Fatalf("SetupTurn failed: %v", err) + } + if !exec.usedLight { + t.Fatal("expected light routing to be used") + } + if exec.llmModelName != "light-model" { + t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "light-model") + } +} + func TestRunTurn_FinalizeSaveErrorEmitsErrorTurnEnd(t *testing.T) { al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{}) defer cleanup() diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index ad5d59723..ddd1eb894 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -84,6 +84,7 @@ const ( type turnResult struct { finalContent string + modelName string status TurnEndStatus followUps []bus.InboundMessage } @@ -140,6 +141,7 @@ type turnExecution struct { callMessages []providers.Message providerToolDefs []providers.ToolDefinition llmModel string + llmModelName string llmOpts map[string]any gracefulTerminal bool useNativeSearch bool diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go index 0cfd435b0..ad4421959 100644 --- a/pkg/channels/interfaces.go +++ b/pkg/channels/interfaces.go @@ -20,6 +20,17 @@ type MessageEditor interface { EditMessage(ctx context.Context, chatID string, messageID string, content string) error } +// MessageEditorWithPayload extends MessageEditor for channels that can update +// structured message metadata in addition to plain text content. +type MessageEditorWithPayload interface { + EditMessageWithPayload( + ctx context.Context, + chatID string, + messageID string, + payload map[string]any, + ) error +} + // MessageDeleter — channels that can delete a message by ID. type MessageDeleter interface { DeleteMessage(ctx context.Context, chatID string, messageID string) error diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 43423148c..03f65e11a 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -191,6 +191,19 @@ func outboundMessageBypassesPlaceholderEdit(msg bus.OutboundMessage) bool { return strings.EqualFold(kind, "thought") || strings.EqualFold(kind, "tool_calls") } +func outboundMessageEditPayload(msg bus.OutboundMessage, content string) map[string]any { + payload := map[string]any{ + "content": content, + } + if len(msg.Context.Raw) == 0 { + return payload + } + if modelName := strings.TrimSpace(msg.Context.Raw["model_name"]); modelName != "" { + payload["model_name"] = modelName + } + return payload +} + func outboundMediaChannel(msg bus.OutboundMediaMessage) string { return msg.Context.Channel } @@ -394,7 +407,16 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess if deleter, ok := ch.(MessageDeleter); ok { deleter.DeleteMessage(ctx, chatID, entry.id) // best effort } else if editor, ok := ch.(MessageEditor); ok { - editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback + if payloadEditor, ok := ch.(MessageEditorWithPayload); ok { + _ = payloadEditor.EditMessageWithPayload( + ctx, + chatID, + entry.id, + outboundMessageEditPayload(msg, msg.Content), + ) + } else { + editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback + } } } } @@ -446,7 +468,18 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess trackedContent = prepareToolFeedbackMessageContent(ch, msg.Content) content = InitialAnimatedToolFeedbackContent(trackedContent) } - if err := editor.EditMessage(ctx, chatID, entry.id, content); err == nil { + err := func() error { + if payloadEditor, ok := ch.(MessageEditorWithPayload); ok { + return payloadEditor.EditMessageWithPayload( + ctx, + chatID, + entry.id, + outboundMessageEditPayload(msg, content), + ) + } + return editor.EditMessage(ctx, chatID, entry.id, content) + }() + if err == nil { trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, &msg.Context) if tracker, ok := ch.(toolFeedbackMessageTracker); ok && isToolFeedback { tracker.RecordToolFeedbackMessage(trackedChatID, entry.id, trackedContent) @@ -643,6 +676,18 @@ func reasoningStreamerFrom(streamer bus.Streamer) bus.ReasoningStreamer { return nil } +type modelNameStreamer interface { + SetModelName(modelName string) +} + +func setStreamerModelName(streamer any, modelName string) { + setter, ok := streamer.(modelNameStreamer) + if !ok { + return + } + setter.SetModelName(modelName) +} + // splitMarkerStreamer turns accumulated streaming text containing // MessageSplitMarker into separate channel stream messages. type splitMarkerStreamer struct { @@ -654,6 +699,7 @@ type splitMarkerStreamer struct { finalized bool onFinalize func(context.Context, string) clearMarker func() + modelName string } func (s *splitMarkerStreamer) Update(ctx context.Context, content string) error { @@ -682,6 +728,7 @@ func (s *splitMarkerStreamer) UpdateReasoning(ctx context.Context, content strin if s.reasoning == nil { return nil } + setStreamerModelName(s.reasoning, s.modelName) return s.reasoning.UpdateReasoning(ctx, content) } @@ -691,9 +738,18 @@ func (s *splitMarkerStreamer) FinalizeReasoning(ctx context.Context, content str if s.reasoning == nil { return nil } + setStreamerModelName(s.reasoning, s.modelName) return s.reasoning.FinalizeReasoning(ctx, content) } +func (s *splitMarkerStreamer) SetModelName(modelName string) { + s.mu.Lock() + defer s.mu.Unlock() + s.modelName = strings.TrimSpace(modelName) + setStreamerModelName(s.current, s.modelName) + setStreamerModelName(s.reasoning, s.modelName) +} + func (s *splitMarkerStreamer) Cancel(ctx context.Context) { s.mu.Lock() defer s.mu.Unlock() @@ -772,6 +828,7 @@ func (s *splitMarkerStreamer) ensureCurrentLocked(ctx context.Context) error { return err } s.current = streamer + setStreamerModelName(s.current, s.modelName) return nil } @@ -856,6 +913,10 @@ func (s *finalizeHookStreamer) FinalizeReasoning(ctx context.Context, content st return nil } +func (s *finalizeHookStreamer) SetModelName(modelName string) { + setStreamerModelName(s.Streamer, strings.TrimSpace(modelName)) +} + func (s *finalizeHookStreamer) runFinalizeHook(ctx context.Context, content string) { if s.onFinalize != nil { s.onFinalize(ctx, content) diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index af3bc6c05..a7ceac67d 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -142,11 +142,21 @@ func (m *mockReasoningStreamer) FinalizeReasoning(_ context.Context, content str return nil } +type modelTrackingReasoningStreamer struct { + mockReasoningStreamer + modelNames []string +} + +func (m *modelTrackingReasoningStreamer) SetModelName(modelName string) { + m.modelNames = append(m.modelNames, strings.TrimSpace(modelName)) +} + type recordingStreamSegment struct { updates []string finals []string finalUsage *bus.ContextUsage canceledCount int + modelNames []string } func (s *recordingStreamSegment) Update(_ context.Context, content string) error { @@ -168,6 +178,10 @@ func (s *recordingStreamSegment) Cancel(context.Context) { s.canceledCount++ } +func (s *recordingStreamSegment) SetModelName(modelName string) { + s.modelNames = append(s.modelNames, strings.TrimSpace(modelName)) +} + type mockStreamingChannel struct { mockMessageEditor streamer Streamer @@ -2068,6 +2082,42 @@ func TestGetStreamer_PreservesReasoningStreamer(t *testing.T) { } } +func TestGetStreamer_PreservesModelNameSetter(t *testing.T) { + m := newTestManager() + inner := &modelTrackingReasoningStreamer{} + ch := &mockStreamingChannel{ + streamer: inner, + } + m.channels["test"] = ch + + streamer, ok := m.GetStreamer(context.Background(), "test", "123", "") + if !ok { + t.Fatal("expected streamer to be available") + } + setter, ok := streamer.(interface{ SetModelName(modelName string) }) + if !ok { + t.Fatal("manager-wrapped streamer should preserve SetModelName") + } + setter.SetModelName("gpt-5.4") + if err := streamer.Update(context.Background(), "hello"); err != nil { + t.Fatalf("Update() error = %v", err) + } + reasoningStreamer, ok := streamer.(bus.ReasoningStreamer) + if !ok { + t.Fatal("manager-wrapped streamer should preserve ReasoningStreamer") + } + setter.SetModelName("gpt-5.4") + if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil { + t.Fatalf("UpdateReasoning() error = %v", err) + } + if len(inner.modelNames) != 2 { + t.Fatalf("model name calls = %v, want 2 forwarded calls", inner.modelNames) + } + if inner.modelNames[0] != "gpt-5.4" || inner.modelNames[1] != "gpt-5.4" { + t.Fatalf("model name calls = %v, want both forwarded as gpt-5.4", inner.modelNames) + } +} + func TestGetStreamer_SplitOnMarkerStreamsSeparateSegments(t *testing.T) { m := newTestManager() m.config = &config.Config{ @@ -2188,6 +2238,58 @@ func TestGetStreamer_SplitOnMarkerKeepsReasoningOnInitialStreamer(t *testing.T) } } +func TestGetStreamer_SplitOnMarkerPreservesModelNameSetter(t *testing.T) { + m := newTestManager() + m.config = &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + SplitOnMarker: true, + }, + }, + } + + initial := &modelTrackingReasoningStreamer{} + next := &recordingStreamSegment{} + callCount := 0 + ch := &mockStreamingChannel{ + beginStreamFn: func(context.Context, string) (Streamer, error) { + callCount++ + if callCount == 1 { + return initial, nil + } + return next, nil + }, + } + m.channels["test"] = ch + + streamer, ok := m.GetStreamer(context.Background(), "test", "123", "") + if !ok { + t.Fatal("expected streamer to be available") + } + setter, ok := streamer.(interface{ SetModelName(modelName string) }) + if !ok { + t.Fatal("split streamer should preserve SetModelName") + } + setter.SetModelName("gpt-5.4-mini") + if err := streamer.Update(context.Background(), "hello<|[SPLIT]|>world"); err != nil { + t.Fatalf("Update() error = %v", err) + } + reasoningStreamer, ok := streamer.(bus.ReasoningStreamer) + if !ok { + t.Fatal("split streamer should preserve ReasoningStreamer") + } + if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil { + t.Fatalf("UpdateReasoning() error = %v", err) + } + + if len(initial.modelNames) == 0 || initial.modelNames[0] != "gpt-5.4-mini" { + t.Fatalf("initial model names = %v, want forwarded gpt-5.4-mini", initial.modelNames) + } + if len(next.modelNames) == 0 || next.modelNames[0] != "gpt-5.4-mini" { + t.Fatalf("next model names = %v, want forwarded gpt-5.4-mini", next.modelNames) + } +} + func TestGetStreamer_FinalizeSeparateMessagesClearsTrackedToolFeedback(t *testing.T) { m := newTestManager() m.config = &config.Config{ diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index e15ffa6c2..3c3938989 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -325,6 +325,9 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri PayloadKeyContent: content, "message_id": msgID, } + if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" { + payload[PayloadKeyModelName] = modelName + } switch { case isThought: payload[PayloadKeyKind] = MessageKindThought @@ -359,6 +362,15 @@ func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID return c.editMessage(ctx, chatID, messageID, content, nil) } +func (c *PicoChannel) EditMessageWithPayload( + ctx context.Context, + chatID string, + messageID string, + payload map[string]any, +) error { + return c.editMessagePayload(ctx, chatID, messageID, payload, nil) +} + // DeleteMessage implements channels.MessageDeleter. func (c *PicoChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error { outMsg := newMessage(TypeMessageDelete, map[string]any{ @@ -419,14 +431,23 @@ func (c *PicoChannel) finalizeTrackedToolFeedbackMessage( ctx context.Context, chatID string, content string, - editFn func(context.Context, string, string, string, *bus.ContextUsage) error, + editFn func(context.Context, string, string, map[string]any, *bus.ContextUsage) error, + payload map[string]any, contextUsage *bus.ContextUsage, ) ([]string, bool) { msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID) if !ok || editFn == nil { return nil, false } - if err := editFn(ctx, chatID, msgID, content, contextUsage); err != nil { + if payload == nil { + payload = map[string]any{ + PayloadKeyContent: content, + } + } + if _, ok := payload[PayloadKeyContent]; !ok { + payload[PayloadKeyContent] = content + } + if err := editFn(ctx, chatID, msgID, payload, contextUsage); err != nil { c.RecordToolFeedbackMessage(chatID, msgID, baseContent) return nil, false } @@ -437,7 +458,20 @@ func (c *PicoChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.O if !outboundMessageFinalizesTrackedToolFeedback(msg) { return nil, false } - return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.editMessage, msg.ContextUsage) + payload := map[string]any{ + PayloadKeyContent: msg.Content, + } + if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" { + payload[PayloadKeyModelName] = modelName + } + return c.finalizeTrackedToolFeedbackMessage( + ctx, + msg.ChatID, + msg.Content, + c.editMessagePayload, + payload, + msg.ContextUsage, + ) } // StartTyping implements channels.TypingCapable. @@ -496,6 +530,7 @@ func (c *PicoChannel) BeginStream(ctx context.Context, chatID string) (channels. type picoStreamer struct { channel *PicoChannel chatID string + modelName string messageID string reasoningID string throttleInterval time.Duration @@ -509,6 +544,15 @@ type picoStreamer struct { mu sync.Mutex } +func (s *picoStreamer) SetModelName(modelName string) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.modelName = strings.TrimSpace(modelName) +} + func (s *picoStreamer) Update(ctx context.Context, content string) error { s.mu.Lock() defer s.mu.Unlock() @@ -613,13 +657,23 @@ func (s *picoStreamer) sendLocked(ctx context.Context, content string, contextUs PayloadKeyContent: content, "message_id": s.messageID, } + if s.modelName != "" { + payload[PayloadKeyModelName] = s.modelName + } setContextUsagePayload(payload, contextUsage) outMsg := newMessage(TypeMessageCreate, payload) if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil { return err } } else if content != s.lastContent || contextUsage != nil { - if err := s.channel.editMessage(ctx, s.chatID, s.messageID, content, contextUsage); err != nil { + payload := map[string]any{ + PayloadKeyContent: content, + "message_id": s.messageID, + } + if s.modelName != "" { + payload[PayloadKeyModelName] = s.modelName + } + if err := s.channel.editMessagePayload(ctx, s.chatID, s.messageID, payload, contextUsage); err != nil { return err } } @@ -642,6 +696,9 @@ func (s *picoStreamer) sendReasoningLocked(ctx context.Context, content string) PayloadKeyKind: MessageKindThought, PayloadKeyThought: true, } + if s.modelName != "" { + payload[PayloadKeyModelName] = s.modelName + } outMsg := newMessage(TypeMessageCreate, payload) if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil { return err @@ -653,6 +710,9 @@ func (s *picoStreamer) sendReasoningLocked(ctx context.Context, content string) PayloadKeyKind: MessageKindThought, PayloadKeyThought: true, } + if s.modelName != "" { + payload[PayloadKeyModelName] = s.modelName + } outMsg := newMessage(TypeMessageUpdate, payload) if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil { return err @@ -744,6 +804,9 @@ func (c *PicoChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessag "attachments": attachments, "message_id": msgID, }) + if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" { + outMsg.Payload[PayloadKeyModelName] = modelName + } if err := c.broadcastToSession(msg.ChatID, outMsg); err != nil { return nil, err @@ -1358,11 +1421,30 @@ func (c *PicoChannel) editMessage( content string, contextUsage *bus.ContextUsage, ) error { - payload := map[string]any{ - "message_id": messageID, - "content": content, + return c.editMessagePayload(ctx, chatID, messageID, map[string]any{ + PayloadKeyContent: content, + }, contextUsage) +} + +func (c *PicoChannel) editMessagePayload( + ctx context.Context, + chatID string, + messageID string, + payload map[string]any, + contextUsage *bus.ContextUsage, +) error { + if payload == nil { + payload = map[string]any{} } - setContextUsagePayload(payload, contextUsage) - outMsg := newMessage(TypeMessageUpdate, payload) + normalized := make(map[string]any, len(payload)+1) + for key, value := range payload { + normalized[key] = value + } + if _, ok := normalized[PayloadKeyContent]; !ok { + normalized[PayloadKeyContent] = "" + } + normalized["message_id"] = messageID + setContextUsagePayload(normalized, contextUsage) + outMsg := newMessage(TypeMessageUpdate, normalized) return c.broadcastToSession(chatID, outMsg) } diff --git a/pkg/channels/pico/pico_test.go b/pkg/channels/pico/pico_test.go index fb0959526..9e90fc539 100644 --- a/pkg/channels/pico/pico_test.go +++ b/pkg/channels/pico/pico_test.go @@ -46,12 +46,15 @@ func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T context.Background(), "pico:chat-1", "final reply", - func(_ context.Context, chatID, messageID, content string, contextUsage *bus.ContextUsage) error { + func(_ context.Context, chatID, messageID string, payload map[string]any, contextUsage *bus.ContextUsage) error { if _, ok := ch.currentToolFeedbackMessage(chatID); ok { t.Fatal("expected tracked tool feedback to be stopped before edit") } - if chatID != "pico:chat-1" || messageID != "msg-1" || content != "final reply" { - t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content) + if chatID != "pico:chat-1" || messageID != "msg-1" { + t.Fatalf("unexpected edit args: %s %s", chatID, messageID) + } + if got := payload[PayloadKeyContent]; got != "final reply" { + t.Fatalf("unexpected content payload: %#v", got) } if contextUsage != nil { t.Fatalf("unexpected context usage: %+v", contextUsage) @@ -59,6 +62,7 @@ func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T return nil }, nil, + nil, ) if !handled { t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message") @@ -115,7 +119,8 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) { Channel: "pico", ChatID: "pico:sess-1", Raw: map[string]string{ - "message_kind": MessageKindThought, + "message_kind": MessageKindThought, + PayloadKeyModelName: "gpt-5.4-mini", }, }, }); err != nil { @@ -134,6 +139,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) { if got := payload[PayloadKeyKind]; got != MessageKindThought { t.Fatalf("thought kind = %#v, want %q", got, MessageKindThought) } + if got := payload[PayloadKeyModelName]; got != "gpt-5.4-mini" { + t.Fatalf("thought model_name = %#v, want %q", got, "gpt-5.4-mini") + } if got := payload["message_id"]; got == "msg-progress" || got == nil || got == "" { t.Fatalf("thought message_id = %#v, want new non-progress id", got) } @@ -151,6 +159,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) { Context: bus.InboundContext{ Channel: "pico", ChatID: "pico:sess-1", + Raw: map[string]string{ + PayloadKeyModelName: "gpt-5.4", + }, }, ContextUsage: &bus.ContextUsage{ UsedTokens: 321, @@ -174,6 +185,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) { if got := payload[PayloadKeyContent]; got != "final reply" { t.Fatalf("final content = %#v, want %q", got, "final reply") } + if got := payload[PayloadKeyModelName]; got != "gpt-5.4" { + t.Fatalf("final model_name = %#v, want %q", got, "gpt-5.4") + } rawUsage, ok := payload["context_usage"].(map[string]any) if !ok { t.Fatalf("final context_usage = %#v, want map payload", payload["context_usage"]) @@ -193,6 +207,54 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) { } } +func TestSend_ToolCallsMessageIncludesModelName(t *testing.T) { + ch := newTestPicoChannel(t) + + if err := ch.Start(context.Background()); err != nil { + t.Fatalf("Start() error = %v", err) + } + defer ch.Stop(context.Background()) + + clientConn, received, cleanup := newTestPicoWebSocket(t) + defer cleanup() + ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"}) + + if _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "pico:sess-1", + Content: "", + Context: bus.InboundContext{ + Channel: "pico", + ChatID: "pico:sess-1", + Raw: map[string]string{ + "message_kind": MessageKindToolCalls, + PayloadKeyModelName: "gpt-5.4", + PayloadKeyToolCalls: `[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"README.md\"}"}}]`, + }, + }, + }); err != nil { + t.Fatalf("Send(tool_calls) error = %v", err) + } + + select { + case msg := <-received: + if msg.Type != TypeMessageCreate { + t.Fatalf("tool_calls message type = %q, want %q", msg.Type, TypeMessageCreate) + } + payload := msg.Payload + if got := payload[PayloadKeyKind]; got != MessageKindToolCalls { + t.Fatalf("tool_calls kind = %#v, want %q", got, MessageKindToolCalls) + } + if got := payload[PayloadKeyModelName]; got != "gpt-5.4" { + t.Fatalf("tool_calls model_name = %#v, want %q", got, "gpt-5.4") + } + if _, ok := payload[PayloadKeyToolCalls].([]any); !ok { + t.Fatalf("tool_calls payload = %#v, want parsed array", payload[PayloadKeyToolCalls]) + } + case <-time.After(time.Second): + t.Fatal("expected tool_calls message to be delivered") + } +} + func TestSendPlaceholder_EmitsNormalMessageWithoutKind(t *testing.T) { ch := newTestPicoChannel(t) ch.bc.Placeholder.Enabled = true @@ -257,6 +319,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) { if err != nil { t.Fatalf("BeginStream() error = %v", err) } + if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok { + setter.SetModelName("gpt-5.4") + } if err := streamer.Update(context.Background(), "hello"); err != nil { t.Fatalf("Update(first) error = %v", err) } @@ -271,6 +336,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) { if got := first.Payload[PayloadKeyContent]; got != "hello" { t.Fatalf("first content = %#v, want hello", got) } + if got := first.Payload[PayloadKeyModelName]; got != "gpt-5.4" { + t.Fatalf("first model_name = %#v, want %q", got, "gpt-5.4") + } rawStreamer := streamer.(*picoStreamer) rawStreamer.mu.Lock() @@ -290,6 +358,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) { if got := second.Payload[PayloadKeyContent]; got != secondContent { t.Fatalf("second content = %#v, want %q", got, secondContent) } + if got := second.Payload[PayloadKeyModelName]; got != "gpt-5.4" { + t.Fatalf("second model_name = %#v, want %q", got, "gpt-5.4") + } } func TestBeginStream_DefaultStreamingShowsSmallIncrements(t *testing.T) { @@ -355,6 +426,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) { if !ok { t.Fatal("pico stream should support reasoning updates") } + if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok { + setter.SetModelName("gpt-5.4-mini") + } if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil { t.Fatalf("UpdateReasoning(first) error = %v", err) } @@ -372,6 +446,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) { if got := first.Payload[PayloadKeyContent]; got != "thinking" { t.Fatalf("first content = %#v, want thinking", got) } + if got := first.Payload[PayloadKeyModelName]; got != "gpt-5.4-mini" { + t.Fatalf("first model_name = %#v, want %q", got, "gpt-5.4-mini") + } if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking more"); err != nil { t.Fatalf("UpdateReasoning(second) error = %v", err) @@ -389,6 +466,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) { if got := second.Payload[PayloadKeyContent]; got != "thinking more" { t.Fatalf("second content = %#v, want thinking more", got) } + if got := second.Payload[PayloadKeyModelName]; got != "gpt-5.4-mini" { + t.Fatalf("second model_name = %#v, want %q", got, "gpt-5.4-mini") + } } func TestBeginStream_ThrottlesIntermediateUpdatesAndFinalFlushes(t *testing.T) { @@ -473,6 +553,9 @@ func TestBeginStream_FinalizeIncludesContextUsage(t *testing.T) { if err != nil { t.Fatalf("BeginStream() error = %v", err) } + if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok { + setter.SetModelName("gpt-5.4") + } if err := streamer.Update(context.Background(), "partial"); err != nil { t.Fatalf("Update() error = %v", err) } @@ -501,6 +584,9 @@ func TestBeginStream_FinalizeIncludesContextUsage(t *testing.T) { if got := final.Payload["message_id"]; got != msgID { t.Fatalf("final message_id = %#v, want %q", got, msgID) } + if got := final.Payload[PayloadKeyModelName]; got != "gpt-5.4" { + t.Fatalf("final model_name = %#v, want %q", got, "gpt-5.4") + } rawUsage, ok := final.Payload["context_usage"].(map[string]any) if !ok { t.Fatalf("final context_usage = %#v, want map", final.Payload["context_usage"]) diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index 320c44d16..0c809430d 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -27,6 +27,7 @@ const ( PayloadKeyKind = "kind" PayloadKeyPlaceholder = "placeholder" PayloadKeyToolCalls = "tool_calls" + PayloadKeyModelName = "model_name" MessageKindThought = "thought" MessageKindToolCalls = "tool_calls" diff --git a/pkg/memory/jsonl_test.go b/pkg/memory/jsonl_test.go index 3a7b98130..c77a4393e 100644 --- a/pkg/memory/jsonl_test.go +++ b/pkg/memory/jsonl_test.go @@ -130,6 +130,32 @@ func TestAddFullMessage_WithToolCalls(t *testing.T) { } } +func TestAddFullMessage_PreservesModelName(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "assistant", + Content: "done", + ModelName: "gpt-5.4-mini", + } + + if err := store.AddFullMessage(ctx, "model-name", msg); err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "model-name") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if history[0].ModelName != "gpt-5.4-mini" { + t.Fatalf("ModelName = %q, want %q", history[0].ModelName, "gpt-5.4-mini") + } +} + func TestAddFullMessage_ToolCallID(t *testing.T) { store := newTestStore(t) ctx := context.Background() diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go index 36092105b..ca0e652d1 100644 --- a/pkg/providers/fallback.go +++ b/pkg/providers/fallback.go @@ -17,6 +17,7 @@ type FallbackChain struct { type FallbackCandidate struct { Provider string Model string + DisplayName string // optional configured alias/raw model label for persistence/UI RPM int // requests per minute; 0 means unrestricted IdentityKey string // optional stable config identity for cooldown/rate limiting } @@ -32,10 +33,11 @@ func (c FallbackCandidate) StableKey() string { // FallbackResult contains the successful response and metadata about all attempts. type FallbackResult struct { - Response *LLMResponse - Provider string - Model string - Attempts []FallbackAttempt + Response *LLMResponse + Provider string + Model string + IdentityKey string + Attempts []FallbackAttempt } // FallbackAttempt records one attempt in the fallback chain. @@ -85,8 +87,9 @@ func ResolveCandidatesWithLookup( } seen[key] = true candidates = append(candidates, FallbackCandidate{ - Provider: ref.Provider, - Model: ref.Model, + Provider: ref.Provider, + Model: ref.Model, + DisplayName: candidateRaw, }) } @@ -187,6 +190,7 @@ func (fc *FallbackChain) Execute( result.Response = resp result.Provider = candidate.Provider result.Model = candidate.Model + result.IdentityKey = candidate.StableKey() return result, nil } @@ -305,6 +309,7 @@ func (fc *FallbackChain) ExecuteImage( result.Response = resp result.Provider = candidate.Provider result.Model = candidate.Model + result.IdentityKey = candidate.StableKey() return result, nil } diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 37a929a58..650bcb287 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -86,6 +86,7 @@ type Attachment struct { type Message struct { Role string `json:"role"` Content string `json:"content"` + ModelName string `json:"model_name,omitempty"` Media []string `json:"media,omitempty"` Attachments []Attachment `json:"attachments,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"` diff --git a/pkg/seahorse/schema.go b/pkg/seahorse/schema.go index 5b67fe9e0..97e638c91 100644 --- a/pkg/seahorse/schema.go +++ b/pkg/seahorse/schema.go @@ -46,6 +46,7 @@ func runSchema(db *sql.DB) error { conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id), role TEXT NOT NULL, content TEXT NOT NULL DEFAULT '', + model_name TEXT NOT NULL DEFAULT '', reasoning_content TEXT NOT NULL DEFAULT '', token_count INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL DEFAULT (datetime('now')) @@ -162,6 +163,9 @@ func runSchema(db *sql.DB) error { if err := ensureMessagesReasoningContentColumn(db); err != nil { return err } + if err := ensureMessagesModelNameColumn(db); err != nil { + return err + } return nil } @@ -180,6 +184,21 @@ func ensureMessagesReasoningContentColumn(db *sql.DB) error { return nil } +func ensureMessagesModelNameColumn(db *sql.DB) error { + hasColumn, err := tableHasColumn(db, "messages", "model_name") + if err != nil { + return fmt.Errorf("check messages.model_name: %w", err) + } + if hasColumn { + return nil + } + + if _, err := db.Exec(`ALTER TABLE messages ADD COLUMN model_name TEXT NOT NULL DEFAULT ''`); err != nil { + return fmt.Errorf("add messages.model_name: %w", err) + } + return nil +} + func tableHasColumn(db *sql.DB, tableName, columnName string) (bool, error) { rows, err := db.Query(fmt.Sprintf(`PRAGMA table_info(%s)`, tableName)) if err != nil { diff --git a/pkg/seahorse/schema_test.go b/pkg/seahorse/schema_test.go index 943b742b2..4618eeff4 100644 --- a/pkg/seahorse/schema_test.go +++ b/pkg/seahorse/schema_test.go @@ -138,6 +138,37 @@ func TestRunSchemaAddsMessagesReasoningContentColumn(t *testing.T) { } } +func TestRunSchemaAddsMessagesModelNameColumn(t *testing.T) { + db := openTestDB(t) + + _, err := db.Exec(`CREATE TABLE messages ( + message_id INTEGER PRIMARY KEY AUTOINCREMENT, + conversation_id INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL DEFAULT '', + reasoning_content TEXT NOT NULL DEFAULT '', + token_count INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + )`) + if err != nil { + t.Fatalf("create legacy messages table: %v", err) + } + + err = runSchema(db) + if err != nil { + t.Fatalf("runSchema: %v", err) + } + + var count int + err = db.QueryRow(`SELECT count(*) FROM pragma_table_info('messages') WHERE name = 'model_name'`).Scan(&count) + if err != nil { + t.Fatalf("query pragma_table_info: %v", err) + } + if count != 1 { + t.Fatalf("model_name column count = %d, want 1", count) + } +} + func TestMigrationConversationUnique(t *testing.T) { db := openTestDB(t) if err := runSchema(db); err != nil { diff --git a/pkg/seahorse/short_engine.go b/pkg/seahorse/short_engine.go index 0a8175617..d275da516 100644 --- a/pkg/seahorse/short_engine.go +++ b/pkg/seahorse/short_engine.go @@ -258,6 +258,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa conv.ConversationID, msg.Role, msg.Parts, + msg.ModelName, msg.ReasoningContent, msg.TokenCount, ) @@ -267,6 +268,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa conv.ConversationID, msg.Role, msg.Content, + msg.ModelName, msg.ReasoningContent, msg.TokenCount, ) @@ -431,6 +433,31 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me return fmt.Errorf("bootstrap: get messages: %w", err) } + // Migration repair path: old SeaHorse rows may be missing reasoning_content + // even though the canonical JSONL history already has it. Backfill those + // rows in place so we do not treat this as edited history and leave stale + // summaries/context behind after a partial raw-message rebuild. + repairedReasoning, err := e.repairBootstrapReasoningContent(ctx, dbMsgs, messages) + if err != nil { + return fmt.Errorf("bootstrap: repair reasoning_content: %w", err) + } + repairedModelName, err := e.repairBootstrapModelName(ctx, dbMsgs, messages) + if err != nil { + return fmt.Errorf("bootstrap: repair model_name: %w", err) + } + if (repairedReasoning || repairedModelName) && len(dbMsgs) == len(messages) { + matched := true + for i := range messages { + if !messageMatches(dbMsgs[i], messages[i]) { + matched = false + break + } + } + if matched { + return nil + } + } + // Fast path: DB has same count and exact match → no-op if len(dbMsgs) == len(messages) { matched := true @@ -445,16 +472,6 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me } } - // Migration repair path: old SeaHorse rows may be missing reasoning_content - // even though the canonical JSONL history already has it. Backfill those - // rows in place so we do not treat this as edited history and leave stale - // summaries/context behind after a partial raw-message rebuild. - if repaired, err := e.repairBootstrapReasoningContent(ctx, dbMsgs, messages); err != nil { - return fmt.Errorf("bootstrap: repair reasoning_content: %w", err) - } else if repaired && len(dbMsgs) == len(messages) { - return nil - } - // Find longest matching prefix from the start anchor := -1 compareLen := min(len(dbMsgs), len(messages)) @@ -465,14 +482,16 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me } else { // Mismatch detected - log details and rebuild logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{ - "conv_id": conv.ConversationID, - "index": i, - "db_role": dbMsgs[i].Role, - "db_content": truncate(dbMsgs[i].Content, 50), - "db_parts": len(dbMsgs[i].Parts), - "msg_role": messages[i].Role, - "msg_content": truncate(messages[i].Content, 50), - "msg_parts": len(messages[i].Parts), + "conv_id": conv.ConversationID, + "index": i, + "db_role": dbMsgs[i].Role, + "db_content": truncate(dbMsgs[i].Content, 50), + "db_parts": len(dbMsgs[i].Parts), + "db_model_name": dbMsgs[i].ModelName, + "msg_role": messages[i].Role, + "msg_content": truncate(messages[i].Content, 50), + "msg_parts": len(messages[i].Parts), + "msg_model_name": messages[i].ModelName, }) break } @@ -559,7 +578,7 @@ func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, me } for i := range overlap { - if !messageMatchesIgnoringReasoning(dbMsgs[i], messages[i]) { + if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) { return false, nil } if dbMsgs[i].ReasoningContent == messages[i].ReasoningContent { @@ -596,6 +615,57 @@ func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, me return true, nil } +func (e *Engine) repairBootstrapModelName(ctx context.Context, dbMsgs, messages []Message) (bool, error) { + if len(dbMsgs) == 0 || len(messages) == 0 { + return false, nil + } + + overlap := min(len(messages), len(dbMsgs)) + + var updates []struct { + index int + messageID int64 + modelName string + } + + for i := range overlap { + if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) { + return false, nil + } + if dbMsgs[i].ModelName == messages[i].ModelName { + continue + } + if messages[i].ModelName == "" { + return false, nil + } + updates = append(updates, struct { + index int + messageID int64 + modelName string + }{ + index: i, + messageID: dbMsgs[i].ID, + modelName: messages[i].ModelName, + }) + } + + if len(updates) == 0 { + return false, nil + } + + for _, update := range updates { + if err := e.store.UpdateMessageModelName(ctx, update.messageID, update.modelName); err != nil { + return false, err + } + dbMsgs[update.index].ModelName = update.modelName + } + + logger.InfoCF("seahorse", "bootstrap: repaired missing model_name", map[string]any{ + "messages": len(updates), + }) + return true, nil +} + // truncate shortens a string for logging. func truncate(s string, maxLen int) string { if len(s) <= maxLen { @@ -610,13 +680,20 @@ func truncate(s string, maxLen int) string { // For messages with Parts (tool_use, tool_result), compare Parts instead of Content // because structured messages are matched by their parts payload. func messageMatches(a, b Message) bool { - if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent { + if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent || a.ModelName != b.ModelName { return false } return messageMatchesIgnoringReasoning(a, b) } func messageMatchesIgnoringReasoning(a, b Message) bool { + if a.ModelName != b.ModelName { + return false + } + return messageMatchesIgnoringReasoningAndModelName(a, b) +} + +func messageMatchesIgnoringReasoningAndModelName(a, b Message) bool { if a.Role != b.Role { return false } diff --git a/pkg/seahorse/short_engine_test.go b/pkg/seahorse/short_engine_test.go index 2a5c6c5d8..337416f6f 100644 --- a/pkg/seahorse/short_engine_test.go +++ b/pkg/seahorse/short_engine_test.go @@ -25,6 +25,43 @@ func newTestEngine(t *testing.T) *Engine { } } +func prepareBootstrapRepairConversation( + t *testing.T, + eng *Engine, + ctx context.Context, + sessionKey string, +) (*Conversation, []Message) { + t.Helper() + + conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3) + if err != nil { + t.Fatalf("AddMessage user: %v", err) + } + + assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3) + if err != nil { + t.Fatalf("AddMessage assistant: %v", err) + } + + if err := eng.store.AppendContextMessages( + ctx, + conv.ConversationID, + []int64{userMsg.ID, assistantMsg.ID}, + ); err != nil { + t.Fatalf("AppendContextMessages: %v", err) + } + + return conv, []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", TokenCount: 3}, + } +} + // --- compileSessionPattern --- func TestCompileSessionPattern(t *testing.T) { @@ -328,6 +365,7 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) { { Role: "assistant", Content: "world", + ModelName: "gpt-5.4-mini", ReasoningContent: "let me think this through", TokenCount: 4, }, @@ -353,6 +391,9 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) { "let me think this through", ) } + if stored[0].ModelName != "gpt-5.4-mini" { + t.Errorf("stored[0].ModelName = %q, want %q", stored[0].ModelName, "gpt-5.4-mini") + } result, err := eng.Assemble(ctx, "agent:reasoning", AssembleInput{Budget: 1000}) if err != nil { @@ -368,6 +409,140 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) { "let me think this through", ) } + if result.Messages[0].ModelName != "gpt-5.4-mini" { + t.Errorf("assembled model_name = %q, want %q", result.Messages[0].ModelName, "gpt-5.4-mini") + } +} + +func TestBootstrapRepairsMissingModelName(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + sessionKey := "agent:repair-model-name" + conv, msgs := prepareBootstrapRepairConversation(t, eng, ctx, sessionKey) + msgs[1].ModelName = "gpt-5.4" + + err := eng.Bootstrap(ctx, sessionKey, msgs) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 2 { + t.Fatalf("stored messages = %d, want 2", len(stored)) + } + if stored[1].ModelName != "gpt-5.4" { + t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4") + } +} + +func TestBootstrapRepairsReasoningContentAndModelNameTogether(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + sessionKey := "agent:repair-both-fields" + + conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3) + if err != nil { + t.Fatalf("AddMessage user: %v", err) + } + + assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3) + if err != nil { + t.Fatalf("AddMessage assistant: %v", err) + } + + err = eng.store.AppendContextMessages(ctx, conv.ConversationID, []int64{userMsg.ID, assistantMsg.ID}) + if err != nil { + t.Fatalf("AppendContextMessages: %v", err) + } + + err = eng.Bootstrap(ctx, sessionKey, []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + { + Role: "assistant", + Content: "world", + ModelName: "gpt-5.4", + ReasoningContent: "let me think this through", + TokenCount: 3, + }, + }) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 2 { + t.Fatalf("stored messages = %d, want 2", len(stored)) + } + if stored[1].ReasoningContent != "let me think this through" { + t.Fatalf("stored[1].ReasoningContent = %q, want %q", stored[1].ReasoningContent, "let me think this through") + } + if stored[1].ModelName != "gpt-5.4" { + t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4") + } +} + +func TestBootstrapRepairsIncorrectNonEmptyModelName(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + sessionKey := "agent:repair-wrong-model-name" + + conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3) + if err != nil { + t.Fatalf("AddMessage user: %v", err) + } + + assistantMsg, err := eng.store.AddMessageWithReasoning( + ctx, + conv.ConversationID, + "assistant", + "world", + "wrong-model", + "", + 3, + ) + if err != nil { + t.Fatalf("AddMessageWithReasoning assistant: %v", err) + } + + err = eng.store.AppendContextMessages(ctx, conv.ConversationID, []int64{userMsg.ID, assistantMsg.ID}) + if err != nil { + t.Fatalf("AppendContextMessages: %v", err) + } + + err = eng.Bootstrap(ctx, sessionKey, []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", ModelName: "gpt-5.4", TokenCount: 3}, + }) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 2 { + t.Fatalf("stored messages = %d, want 2", len(stored)) + } + if stored[1].ModelName != "gpt-5.4" { + t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4") + } } func TestEngineIngestWithPartsPreservesReasoningContent(t *testing.T) { @@ -620,35 +795,10 @@ func TestBootstrapRepairsMissingReasoningContent(t *testing.T) { eng := newTestEngine(t) ctx := context.Background() sessionKey := "agent:repair-reasoning" + conv, msgs := prepareBootstrapRepairConversation(t, eng, ctx, sessionKey) + msgs[1].ReasoningContent = "let me think this through" - conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey) - if err != nil { - t.Fatalf("GetOrCreateConversation: %v", err) - } - - userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3) - if err != nil { - t.Fatalf("AddMessage user: %v", err) - } - - assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3) - if err != nil { - t.Fatalf("AddMessage assistant: %v", err) - } - - err = eng.store.AppendContextMessages( - ctx, - conv.ConversationID, - []int64{userMsg.ID, assistantMsg.ID}, - ) - if err != nil { - t.Fatalf("AppendContextMessages: %v", err) - } - - err = eng.Bootstrap(ctx, sessionKey, []Message{ - {Role: "user", Content: "hello", TokenCount: 3}, - {Role: "assistant", Content: "world", ReasoningContent: "let me think this through", TokenCount: 3}, - }) + err := eng.Bootstrap(ctx, sessionKey, msgs) if err != nil { t.Fatalf("Bootstrap: %v", err) } diff --git a/pkg/seahorse/store.go b/pkg/seahorse/store.go index 0edbbd128..b5e32e89d 100644 --- a/pkg/seahorse/store.go +++ b/pkg/seahorse/store.go @@ -162,19 +162,25 @@ func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Tim // AddMessage appends a message to a conversation. func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) { - return s.AddMessageWithReasoning(ctx, convID, role, content, "", tokenCount) + return s.AddMessageWithReasoning(ctx, convID, role, content, "", "", tokenCount) } // AddMessageWithReasoning appends a message with reasoning content to a conversation. func (s *Store) AddMessageWithReasoning( ctx context.Context, convID int64, - role, content, reasoningContent string, + role, content, modelName, reasoningContent string, tokenCount int, ) (*Message, error) { - result, err := s.db.ExecContext(ctx, - "INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)", - convID, role, content, reasoningContent, tokenCount, + result, err := s.db.ExecContext( + ctx, + "INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)", + convID, + role, + content, + modelName, + reasoningContent, + tokenCount, ) if err != nil { return nil, fmt.Errorf("add message: %w", err) @@ -185,6 +191,7 @@ func (s *Store) AddMessageWithReasoning( ConversationID: convID, Role: role, Content: content, + ModelName: modelName, ReasoningContent: reasoningContent, TokenCount: tokenCount, }, nil @@ -224,7 +231,7 @@ func (s *Store) AddMessageWithParts( parts []MessagePart, tokenCount int, ) (*Message, error) { - return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", tokenCount) + return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", "", tokenCount) } // AddMessageWithPartsAndReasoning adds a message with structured parts and reasoning content. @@ -233,6 +240,7 @@ func (s *Store) AddMessageWithPartsAndReasoning( convID int64, role string, parts []MessagePart, + modelName string, reasoningContent string, tokenCount int, ) (*Message, error) { @@ -245,9 +253,15 @@ func (s *Store) AddMessageWithPartsAndReasoning( // Derive readable content from Parts for FTS5 indexing and summary formatting readableContent := partsToReadableContent(parts) - result, err := tx.ExecContext(ctx, - "INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)", - convID, role, readableContent, reasoningContent, tokenCount, + result, err := tx.ExecContext( + ctx, + "INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)", + convID, + role, + readableContent, + modelName, + reasoningContent, + tokenCount, ) if err != nil { return nil, fmt.Errorf("add message: %w", err) @@ -282,6 +296,7 @@ func (s *Store) AddMessageWithPartsAndReasoning( ID: msgID, ConversationID: convID, Role: role, + ModelName: modelName, ReasoningContent: reasoningContent, TokenCount: tokenCount, Parts: make([]MessagePart, len(parts)), @@ -295,7 +310,7 @@ func (s *Store) AddMessageWithPartsAndReasoning( // GetMessages retrieves messages for a conversation. func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, beforeID int64) ([]Message, error) { - query := "SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE conversation_id = ?" + query := "SELECT message_id, conversation_id, role, content, model_name, reasoning_content, token_count, created_at FROM messages WHERE conversation_id = ?" args := []any{convID} if beforeID > 0 { query += " AND message_id < ?" @@ -322,6 +337,7 @@ func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, before &msg.ConversationID, &msg.Role, &msg.Content, + &msg.ModelName, &msg.ReasoningContent, &msg.TokenCount, &createdAt, @@ -362,9 +378,9 @@ func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message, var createdAt string err := s.db.QueryRowContext( ctx, - "SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE message_id = ?", + "SELECT message_id, conversation_id, role, content, model_name, reasoning_content, token_count, created_at FROM messages WHERE message_id = ?", messageID, - ).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.ReasoningContent, &msg.TokenCount, &createdAt) + ).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.ModelName, &msg.ReasoningContent, &msg.TokenCount, &createdAt) if err == sql.ErrNoRows { return nil, fmt.Errorf("message %d not found", messageID) } @@ -398,6 +414,27 @@ func (s *Store) UpdateMessageReasoningContent(ctx context.Context, messageID int return nil } +func (s *Store) UpdateMessageModelName(ctx context.Context, messageID int64, modelName string) error { + result, err := s.db.ExecContext( + ctx, + "UPDATE messages SET model_name = ? WHERE message_id = ?", + modelName, + messageID, + ) + if err != nil { + return fmt.Errorf("update message model_name: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("update message model_name rows affected: %w", err) + } + if rowsAffected == 0 { + return fmt.Errorf("message %d not found", messageID) + } + return nil +} + func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) { rows, err := s.db.QueryContext(ctx, `SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type @@ -581,8 +618,9 @@ func (s *Store) LinkSummaryToMessages(ctx context.Context, summaryID string, mes // GetSummarySourceMessages retrieves source messages for a summary. func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) ([]Message, error) { - rows, err := s.db.QueryContext(ctx, - `SELECT m.message_id, m.conversation_id, m.role, m.content, m.reasoning_content, m.token_count, m.created_at + rows, err := s.db.QueryContext( + ctx, + `SELECT m.message_id, m.conversation_id, m.role, m.content, m.model_name, m.reasoning_content, m.token_count, m.created_at FROM summary_messages sm JOIN messages m ON m.message_id = sm.message_id WHERE sm.summary_id = ? @@ -603,6 +641,7 @@ func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) &msg.ConversationID, &msg.Role, &msg.Content, + &msg.ModelName, &msg.ReasoningContent, &msg.TokenCount, &createdAt, diff --git a/pkg/seahorse/store_test.go b/pkg/seahorse/store_test.go index 67bed1c11..4ed2bb3bb 100644 --- a/pkg/seahorse/store_test.go +++ b/pkg/seahorse/store_test.go @@ -210,6 +210,7 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) { conv.ConversationID, "assistant", "hello world", + "gpt-5.4-mini", "let me think", 5, ) @@ -219,6 +220,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) { if msg.ReasoningContent != "let me think" { t.Fatalf("ReasoningContent = %q, want %q", msg.ReasoningContent, "let me think") } + if msg.ModelName != "gpt-5.4-mini" { + t.Fatalf("ModelName = %q, want %q", msg.ModelName, "gpt-5.4-mini") + } msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0) if err != nil { @@ -230,6 +234,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) { if msgs[0].ReasoningContent != "let me think" { t.Errorf("ReasoningContent = %q, want %q", msgs[0].ReasoningContent, "let me think") } + if msgs[0].ModelName != "gpt-5.4-mini" { + t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4-mini") + } found, err := s.GetMessageByID(ctx, msg.ID) if err != nil { @@ -238,6 +245,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) { if found.ReasoningContent != "let me think" { t.Errorf("GetMessageByID ReasoningContent = %q, want %q", found.ReasoningContent, "let me think") } + if found.ModelName != "gpt-5.4-mini" { + t.Errorf("GetMessageByID ModelName = %q, want %q", found.ModelName, "gpt-5.4-mini") + } } func TestStoreAddMessageWithParts(t *testing.T) { @@ -288,6 +298,7 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) { conv.ConversationID, "assistant", parts, + "gpt-5.4", "need to inspect the file first", 10, ) @@ -309,6 +320,9 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) { "need to inspect the file first", ) } + if msgs[0].ModelName != "gpt-5.4" { + t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4") + } } func TestStoreGetMessageCount(t *testing.T) { diff --git a/pkg/seahorse/types.go b/pkg/seahorse/types.go index 2bc7f931f..af323d2be 100644 --- a/pkg/seahorse/types.go +++ b/pkg/seahorse/types.go @@ -22,6 +22,7 @@ type Message struct { ConversationID int64 `json:"conversationId"` Role string `json:"role"` Content string `json:"content"` + ModelName string `json:"modelName,omitempty"` ReasoningContent string `json:"reasoningContent,omitempty"` TokenCount int `json:"tokenCount"` CreatedAt time.Time `json:"createdAt"` @@ -135,6 +136,7 @@ func EstimateMessageTokens(msg Message) int { pm := providers.Message{ Role: msg.Role, Content: msg.Content, + ModelName: msg.ModelName, ReasoningContent: msg.ReasoningContent, } diff --git a/pkg/session/jsonl_backend_test.go b/pkg/session/jsonl_backend_test.go index 0b79ad84d..6f67109d6 100644 --- a/pkg/session/jsonl_backend_test.go +++ b/pkg/session/jsonl_backend_test.go @@ -66,6 +66,25 @@ func TestJSONLBackend_AddFullMessage(t *testing.T) { } } +func TestJSONLBackend_AddFullMessage_PreservesModelName(t *testing.T) { + b := newBackend(t) + + msg := providers.Message{ + Role: "assistant", + Content: "done", + ModelName: "gpt-5.4-mini", + } + b.AddFullMessage("s1", msg) + + history := b.GetHistory("s1") + if len(history) != 1 { + t.Fatalf("got %d, want 1", len(history)) + } + if history[0].ModelName != "gpt-5.4-mini" { + t.Fatalf("ModelName = %q, want %q", history[0].ModelName, "gpt-5.4-mini") + } +} + func TestJSONLBackend_Summary(t *testing.T) { b := newBackend(t) diff --git a/web/backend/api/session.go b/web/backend/api/session.go index cc18ee6e1..e221386ca 100644 --- a/web/backend/api/session.go +++ b/web/backend/api/session.go @@ -50,6 +50,7 @@ type sessionChatMessage struct { Role string `json:"role"` Content string `json:"content"` Kind string `json:"kind,omitempty"` + ModelName string `json:"model_name,omitempty"` Media []string `json:"media,omitempty"` Attachments []sessionChatAttachment `json:"attachments,omitempty"` ToolCalls []utils.VisibleToolCall `json:"tool_calls,omitempty"` @@ -510,6 +511,7 @@ func sessionTranscriptMessages( chatMsg := sessionChatMessage{ Role: "user", Content: msg.Content, + ModelName: msg.ModelName, Media: append([]string(nil), msg.Media...), Attachments: attachments, } @@ -529,9 +531,10 @@ func sessionTranscriptMessages( toolCallsMsg, hasToolCallsMsg := assistantToolCallsMessage( msg.ToolCalls, + msg.ModelName, toolFeedbackMaxArgsLength, ) - visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls) + visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls, msg.ModelName) // Pico web chat can persist both visible `message` tool output and a // later plain assistant reply in the same turn. Hide only the fixed @@ -556,6 +559,7 @@ func sessionTranscriptMessages( chatMsg := sessionChatMessage{ Role: "assistant", Content: content, + ModelName: msg.ModelName, Media: append([]string(nil), msg.Media...), Attachments: attachments, } @@ -682,14 +686,16 @@ func assistantThoughtMessage(msg providers.Message) (sessionChatMessage, bool) { return sessionChatMessage{}, false } return sessionChatMessage{ - Role: "assistant", - Content: reasoning, - Kind: "thought", + Role: "assistant", + Content: reasoning, + Kind: "thought", + ModelName: msg.ModelName, }, true } func assistantToolCallsMessage( toolCalls []providers.ToolCall, + modelName string, toolFeedbackMaxArgsLength int, ) (sessionChatMessage, bool) { if len(toolCalls) == 0 { @@ -707,6 +713,7 @@ func assistantToolCallsMessage( return sessionChatMessage{ Role: "assistant", Kind: "tool_calls", + ModelName: modelName, ToolCalls: visibleToolCalls, }, true } @@ -718,7 +725,7 @@ func visibleAssistantToolArgsPreview( return utils.VisibleToolCallArgumentsPreview(tc, toolFeedbackMaxArgsLength) } -func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage { +func visibleAssistantToolMessages(toolCalls []providers.ToolCall, modelName string) []sessionChatMessage { if len(toolCalls) == 0 { return nil } @@ -734,8 +741,9 @@ func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatM continue } messages = append(messages, sessionChatMessage{ - Role: "assistant", - Content: content, + Role: "assistant", + Content: content, + ModelName: modelName, }) } diff --git a/web/backend/api/session_test.go b/web/backend/api/session_test.go index 760935db7..8604f1df6 100644 --- a/web/backend/api/session_test.go +++ b/web/backend/api/session_test.go @@ -564,7 +564,7 @@ func TestHandleGetSession_ReconstructsThoughtFromAssistantReasoningContent(t *te sessionKey := picoSessionPrefix + "detail-reasoning-content" for _, msg := range []providers.Message{ {Role: "user", Content: "hello"}, - {Role: "assistant", Content: "final visible answer", ReasoningContent: "internal chain of thought"}, + {Role: "assistant", Content: "final visible answer", ModelName: "gpt-5.4", ReasoningContent: "internal chain of thought"}, } { if err := store.AddFullMessage(nil, sessionKey, msg); err != nil { t.Fatalf("AddFullMessage() error = %v", err) @@ -597,9 +597,15 @@ func TestHandleGetSession_ReconstructsThoughtFromAssistantReasoningContent(t *te resp.Messages[1].Kind != "thought" { t.Fatalf("thought message = %#v, want assistant thought/internal chain of thought", resp.Messages[1]) } + if resp.Messages[1].ModelName != "gpt-5.4" { + t.Fatalf("thought model_name = %q, want %q", resp.Messages[1].ModelName, "gpt-5.4") + } if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "final visible answer" { t.Fatalf("final message = %#v, want assistant/final visible answer", resp.Messages[2]) } + if resp.Messages[2].ModelName != "gpt-5.4" { + t.Fatalf("final model_name = %q, want %q", resp.Messages[2].ModelName, "gpt-5.4") + } } func TestHandleGetSession_ReconstructsRefreshMatrixForThoughtAndToolSummary(t *testing.T) { @@ -725,8 +731,9 @@ func TestHandleGetSession_ReconstructsVisibleMessageToolOutputWithoutDuplicateSu for _, msg := range []providers.Message{ {Role: "user", Content: "test"}, { - Role: "assistant", - Content: "", + Role: "assistant", + Content: "", + ModelName: "gpt-5.4-mini", ToolCalls: []providers.ToolCall{ { ID: "call_1", @@ -771,9 +778,15 @@ func TestHandleGetSession_ReconstructsVisibleMessageToolOutputWithoutDuplicateSu t.Fatalf("first message = %#v, want user/test", resp.Messages[0]) } assertVisibleToolCallMessage(t, resp.Messages[1], "message") + if resp.Messages[1].ModelName != "gpt-5.4-mini" { + t.Fatalf("tool_calls model_name = %q, want %q", resp.Messages[1].ModelName, "gpt-5.4-mini") + } if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "visible tool output" { t.Fatalf("assistant message = %#v, want visible tool output", resp.Messages[2]) } + if resp.Messages[2].ModelName != "gpt-5.4-mini" { + t.Fatalf("visible tool output model_name = %q, want %q", resp.Messages[2].ModelName, "gpt-5.4-mini") + } } func TestHandleGetSession_PreservesFinalAssistantReplyAfterMessageToolOutput(t *testing.T) { diff --git a/web/frontend/src/api/sessions.ts b/web/frontend/src/api/sessions.ts index edd7d7c27..002a3c5d7 100644 --- a/web/frontend/src/api/sessions.ts +++ b/web/frontend/src/api/sessions.ts @@ -15,6 +15,7 @@ export interface SessionDetail { role: "user" | "assistant" content: string kind?: "normal" | "thought" | "tool_calls" + model_name?: string media?: string[] attachments?: { type?: "image" | "audio" | "video" | "file" diff --git a/web/frontend/src/components/chat/assistant-message.tsx b/web/frontend/src/components/chat/assistant-message.tsx index 3857c6743..c8d0480f0 100644 --- a/web/frontend/src/components/chat/assistant-message.tsx +++ b/web/frontend/src/components/chat/assistant-message.tsx @@ -33,6 +33,7 @@ interface AssistantMessageProps { content: string attachments?: ChatAttachment[] kind?: AssistantMessageKind + modelName?: string toolCalls?: ChatToolCall[] timestamp?: string | number } @@ -41,6 +42,7 @@ export function AssistantMessage({ content, attachments = [], kind = "normal", + modelName, toolCalls = [], timestamp = "", }: AssistantMessageProps) { @@ -66,13 +68,20 @@ export function AssistantMessage({ const copyMessageLabel = isCopied ? t("chat.copiedLabel") : t("chat.copyMessage") + const trimmedModelName = modelName?.trim() ?? "" return (
{!isCollapsedBlock && ( -
+
PicoClaw + {trimmedModelName && ( + <> + + {trimmedModelName} + + )} {formattedTimestamp && ( <> @@ -104,6 +113,9 @@ export function AssistantMessage({ )} {collapsedLabel} + {trimmedModelName && ( + {trimmedModelName} + )}
diff --git a/web/frontend/src/features/chat/history.ts b/web/frontend/src/features/chat/history.ts index 9fc35bc1e..72e9ff332 100644 --- a/web/frontend/src/features/chat/history.ts +++ b/web/frontend/src/features/chat/history.ts @@ -50,6 +50,7 @@ export async function loadSessionMessages( role: message.role, content: message.content, kind: message.role === "assistant" ? (message.kind ?? "normal") : undefined, + modelName: message.model_name, toolCalls: message.role === "assistant" ? parseToolCallsValue(message.tool_calls) @@ -86,7 +87,7 @@ function messageSignature(message: ChatMessage): string { return `${message.role}\u0000${message.content}\u0000${normalizeMessageTimestamp( message.timestamp, - )}\u0000${message.kind ?? ""}\u0000${attachmentSignature}\u0000${toolCallsSignature( + )}\u0000${message.kind ?? ""}\u0000${message.modelName ?? ""}\u0000${attachmentSignature}\u0000${toolCallsSignature( message.toolCalls, )}` } diff --git a/web/frontend/src/features/chat/protocol.ts b/web/frontend/src/features/chat/protocol.ts index 04cc924be..b372b235c 100644 --- a/web/frontend/src/features/chat/protocol.ts +++ b/web/frontend/src/features/chat/protocol.ts @@ -83,6 +83,14 @@ function parseContextUsage( } } +function parseModelName(payload: Record): string | undefined { + if (typeof payload.model_name !== "string") { + return undefined + } + const modelName = payload.model_name.trim() + return modelName || undefined +} + export function handlePicoMessage( message: PicoMessage, expectedSessionId: string, @@ -102,6 +110,7 @@ export function handlePicoMessage( const attachments = parseAttachments(payload) const contextUsage = parseContextUsage(payload) const isPlaceholder = payload.placeholder === true + const modelName = parseModelName(payload) const timestamp = message.timestamp !== undefined && Number.isFinite(Number(message.timestamp)) @@ -116,6 +125,7 @@ export function handlePicoMessage( role: "assistant", content, kind, + ...(modelName ? { modelName } : {}), ...(toolCalls ? { toolCalls } : {}), attachments, timestamp, @@ -135,6 +145,7 @@ export function handlePicoMessage( const messageId = payload.message_id as string const attachments = parseAttachments(payload) const contextUsage = parseContextUsage(payload) + const modelName = parseModelName(payload) const timestamp = message.timestamp !== undefined && Number.isFinite(Number(message.timestamp)) @@ -160,6 +171,7 @@ export function handlePicoMessage( content, kind, toolCalls, + ...(modelName ? { modelName } : {}), ...(attachments ? { attachments } : {}), } }) @@ -178,6 +190,7 @@ export function handlePicoMessage( content, kind, toolCalls, + ...(modelName ? { modelName } : {}), ...(attachments ? { attachments } : {}), timestamp, }, diff --git a/web/frontend/src/store/chat.ts b/web/frontend/src/store/chat.ts index aa5f08490..056d5601b 100644 --- a/web/frontend/src/store/chat.ts +++ b/web/frontend/src/store/chat.ts @@ -44,6 +44,7 @@ export interface ChatMessage { content: string timestamp: number | string kind?: AssistantMessageKind + modelName?: string attachments?: ChatAttachment[] toolCalls?: ChatToolCall[] }