feat(chat,seahorse): persist and display model_name across history (#2897)

* feat(chat,seahorse): persist and display model_name across history

* test(seahorse): fix lint regressions in repair coverage

* fix(pico): preserve model_name in live updates

* fix(pico): preserve model_name through live stream wrappers
This commit is contained in:
LC
2026-05-20 13:42:21 +08:00
committed by GitHub
parent 548dc15acd
commit b7db059544
41 changed files with 1266 additions and 139 deletions
+6
View File
@@ -600,6 +600,12 @@ func (al *AgentLoop) runAgentLoop(
Content: result.finalContent,
ContextUsage: computeContextUsage(agent, opts.Dispatch.SessionKey),
}
if modelName := strings.TrimSpace(result.modelName); modelName != "" {
if msg.Context.Raw == nil {
msg.Context.Raw = make(map[string]string, 1)
}
msg.Context.Raw["model_name"] = modelName
}
markFinalOutbound(&msg)
al.bus.PublishOutbound(ctx, msg)
}
+32 -11
View File
@@ -102,7 +102,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
return ""
}
func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, chatID, sessionKey string) {
func (al *AgentLoop) publishPicoReasoning(
ctx context.Context,
reasoningContent, chatID, sessionKey, modelName string,
) {
if reasoningContent == "" || chatID == "" {
return
}
@@ -114,13 +117,16 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
defer pubCancel()
raw := map[string]string{metadataKeyMessageKind: messageKindThought}
if trimmedModelName := strings.TrimSpace(modelName); trimmedModelName != "" {
raw["model_name"] = trimmedModelName
}
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
Context: bus.InboundContext{
Channel: "pico",
ChatID: chatID,
Raw: map[string]string{
metadataKeyMessageKind: messageKindThought,
},
Raw: raw,
},
SessionKey: sessionKey,
Content: reasoningContent,
@@ -143,6 +149,7 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
func (al *AgentLoop) publishPicoToolCallInterim(
ctx context.Context,
ts *turnState,
modelName string,
reasoningContent string,
content string,
toolCalls []providers.ToolCall,
@@ -155,7 +162,14 @@ func (al *AgentLoop) publishPicoToolCallInterim(
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
err := al.bus.PublishOutbound(
pubCtx,
outboundMessageForTurnWithKind(ts, reasoningContent, messageKindThought),
outboundMessageForTurnWithOptions(
ts,
reasoningContent,
outboundTurnMessageOptions{
kind: messageKindThought,
modelName: modelName,
},
),
)
pubCancel()
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
@@ -182,7 +196,12 @@ func (al *AgentLoop) publishPicoToolCallInterim(
if strings.TrimSpace(content) != "" && !duplicateToolCallContent {
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
err := al.bus.PublishOutbound(pubCtx, outboundMessageForTurn(ts, content))
err := al.bus.PublishOutbound(
pubCtx,
outboundMessageForTurnWithOptions(ts, content, outboundTurnMessageOptions{
modelName: modelName,
}),
)
pubCancel()
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
!errors.Is(err, context.Canceled) &&
@@ -209,11 +228,13 @@ func (al *AgentLoop) publishPicoToolCallInterim(
return
}
msg := outboundMessageForTurnWithKind(ts, "", messageKindToolCalls)
if msg.Context.Raw == nil {
msg.Context.Raw = map[string]string{}
}
msg.Context.Raw[metadataKeyToolCalls] = string(rawToolCalls)
msg := outboundMessageForTurnWithOptions(ts, "", outboundTurnMessageOptions{
kind: messageKindToolCalls,
modelName: modelName,
raw: map[string]string{
metadataKeyToolCalls: string(rawToolCalls),
},
})
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
err = al.bus.PublishOutbound(pubCtx, msg)
+1 -1
View File
@@ -312,7 +312,7 @@ func TestPublishPicoReasoningIncludesSessionKey(t *testing.T) {
defer cleanup()
_ = provider
al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1")
al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1", "")
select {
case outbound := <-msgBus.OutboundChan():
+46 -6
View File
@@ -96,15 +96,46 @@ func markFinalOutbound(msg *bus.OutboundMessage) {
msg.Context.Raw[metadataKeyOutboundKind] = outboundKindFinal
}
func outboundMessageForTurnWithKind(ts *turnState, content, kind string) bus.OutboundMessage {
type outboundTurnMessageOptions struct {
kind string
modelName string
raw map[string]string
}
func outboundMessageForTurnWithOptions(
ts *turnState,
content string,
opts outboundTurnMessageOptions,
) bus.OutboundMessage {
msg := outboundMessageForTurn(ts, content)
if strings.TrimSpace(kind) == "" {
trimmedKind := strings.TrimSpace(opts.kind)
trimmedModelName := strings.TrimSpace(opts.modelName)
rawCount := len(opts.raw)
if trimmedKind != "" {
rawCount++
}
if trimmedModelName != "" {
rawCount++
}
if rawCount == 0 {
return msg
}
if msg.Context.Raw == nil {
msg.Context.Raw = make(map[string]string, 1)
msg.Context.Raw = make(map[string]string, rawCount)
}
if trimmedKind != "" {
msg.Context.Raw[metadataKeyMessageKind] = trimmedKind
}
if trimmedModelName != "" {
msg.Context.Raw["model_name"] = trimmedModelName
}
for key, value := range opts.raw {
if strings.TrimSpace(key) == "" {
continue
}
msg.Context.Raw[key] = value
}
msg.Context.Raw[metadataKeyMessageKind] = kind
return msg
}
@@ -521,8 +552,9 @@ func hasMediaRefs(messages []providers.Message) bool {
func sideQuestionModelName(agent *AgentInstance, usedLight bool) string {
if usedLight && len(agent.LightCandidates) > 0 {
// Use the first light candidate's model
return agent.LightCandidates[0].Model
if name := resolvedCandidateModelName(agent.LightCandidates, ""); name != "" {
return name
}
}
return agent.Model
}
@@ -538,6 +570,14 @@ func modelNameFromIdentityKey(identityKey string) string {
return identityKey
}
func modelAliasFromCandidateIdentityKey(identityKey string) string {
const prefix = "model_name:"
if !strings.HasPrefix(identityKey, prefix) {
return ""
}
return strings.TrimSpace(strings.TrimPrefix(identityKey, prefix))
}
func closeProviderIfStateful(provider providers.LLMProvider) {
if stateful, ok := provider.(providers.StatefulProvider); ok {
stateful.Close()
+2
View File
@@ -197,6 +197,7 @@ func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
result := seahorse.Message{
Role: msg.Role,
Content: msg.Content,
ModelName: msg.ModelName,
ReasoningContent: msg.ReasoningContent,
TokenCount: tokenizer.EstimateMessageTokens(msg),
}
@@ -243,6 +244,7 @@ func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes
pm := protocoltypes.Message{
Role: msg.Role,
Content: msg.Content,
ModelName: msg.ModelName,
ReasoningContent: msg.ReasoningContent,
}
+8
View File
@@ -174,6 +174,7 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
msg := protocoltypes.Message{
Role: "assistant",
Content: "response text",
ModelName: "gpt-5.4-mini",
ReasoningContent: "I thought about this carefully",
}
@@ -181,6 +182,9 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
if result.ReasoningContent != "I thought about this carefully" {
t.Errorf("ReasoningContent = %q, want 'I thought about this carefully'", result.ReasoningContent)
}
if result.ModelName != "gpt-5.4-mini" {
t.Errorf("ModelName = %q, want %q", result.ModelName, "gpt-5.4-mini")
}
}
func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
@@ -189,6 +193,7 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
{
Role: "assistant",
Content: "response",
ModelName: "gpt-5.4",
ReasoningContent: "thinking process",
},
},
@@ -201,6 +206,9 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
if messages[0].ReasoningContent != "thinking process" {
t.Errorf("ReasoningContent = %q, want 'thinking process'", messages[0].ReasoningContent)
}
if messages[0].ModelName != "gpt-5.4" {
t.Errorf("ModelName = %q, want %q", messages[0].ModelName, "gpt-5.4")
}
}
func TestSeahorseToProviderMessages(t *testing.T) {
+16 -2
View File
@@ -75,6 +75,7 @@ func candidateFromModelConfig(
return providers.FallbackCandidate{
Provider: protocol,
Model: modelID,
DisplayName: strings.TrimSpace(mc.ModelName),
RPM: mc.RPM,
IdentityKey: modelConfigIdentityKey(mc),
}, true
@@ -147,8 +148,9 @@ func resolveModelCandidate(
}
return providers.FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
Provider: ref.Provider,
Model: ref.Model,
DisplayName: raw,
}, true
}
@@ -197,6 +199,18 @@ func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallbac
return fallback
}
func resolvedCandidateModelName(candidates []providers.FallbackCandidate, fallback string) string {
if len(candidates) > 0 {
if name := modelAliasFromCandidateIdentityKey(candidates[0].IdentityKey); strings.TrimSpace(name) != "" {
return name
}
if displayName := strings.TrimSpace(candidates[0].DisplayName); displayName != "" {
return displayName
}
}
return strings.TrimSpace(fallback)
}
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
+49
View File
@@ -7,6 +7,55 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
func TestModelNameFromIdentityKey_LegacyProviderModel(t *testing.T) {
if got := modelNameFromIdentityKey("openai/gpt-5.4"); got != "gpt-5.4" {
t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "gpt-5.4")
}
}
func TestModelNameFromIdentityKey_PreservesNonLegacyIdentity(t *testing.T) {
if got := modelNameFromIdentityKey("model_name:primary"); got != "model_name:primary" {
t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "model_name:primary")
}
}
func TestModelAliasFromCandidateIdentityKey(t *testing.T) {
if got := modelAliasFromCandidateIdentityKey("model_name:primary"); got != "primary" {
t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want %q", got, "primary")
}
if got := modelAliasFromCandidateIdentityKey("openai/gpt-5.4"); got != "" {
t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want empty", got)
}
}
func TestResolvedCandidateModelName_PrefersIdentityAlias(t *testing.T) {
got := resolvedCandidateModelName([]providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
}, "fallback-model")
if got != "primary" {
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary")
}
}
func TestResolvedCandidateModelName_DoesNotScanFallbackAliases(t *testing.T) {
got := resolvedCandidateModelName([]providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4"},
{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:fallback"},
}, "primary-model")
if got != "primary-model" {
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary-model")
}
}
func TestResolvedCandidateModelName_UsesCandidateDisplayName(t *testing.T) {
got := resolvedCandidateModelName([]providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4", DisplayName: "gpt-5.4-display"},
}, "fallback-model")
if got != "gpt-5.4-display" {
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "gpt-5.4-display")
}
}
func TestResolveActiveModelConfig_PrefersCandidateIdentityKey(t *testing.T) {
cfg := &config.Config{
ModelList: []*config.ModelConfig{
+10 -2
View File
@@ -180,7 +180,11 @@ toolLoop:
toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen),
)
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions(
ts,
feedbackMsg,
outboundTurnMessageOptions{kind: messageKindToolFeedback},
))
fbCancel()
}
@@ -467,7 +471,11 @@ toolLoop:
toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen),
)
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions(
ts,
feedbackMsg,
outboundTurnMessageOptions{kind: messageKindToolFeedback},
))
fbCancel()
}
+7 -18
View File
@@ -33,6 +33,7 @@ func (p *Pipeline) Finalize(
ts.setPhase(TurnPhaseCompleted)
return turnResult{
finalContent: finalContent,
modelName: exec.llmModelName,
status: turnStatus,
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
}, nil
@@ -44,6 +45,7 @@ func (p *Pipeline) Finalize(
finalMsg := providers.Message{
Role: "assistant",
Content: finalContent,
ModelName: exec.llmModelName,
ReasoningContent: responseReasoningContent(exec.response),
}
ts.agent.Sessions.AddFullMessage(ts.sessionKey, finalMsg)
@@ -80,24 +82,10 @@ func (p *Pipeline) Finalize(
// so the final answer is still delivered outside normal SendResponse.
if ((streamErr != nil && !isConfiguredStreamingVisibleError(streamErr)) || exec.streamingFallback) &&
!ts.opts.SendResponse && ts.opts.AllowInterimPicoPublish && finalContent != "" {
agentID, sessionKey, scope := outboundTurnMetadata(
ts.agent.ID,
ts.opts.Dispatch.SessionKey,
ts.opts.Dispatch.SessionScope,
)
msg := bus.OutboundMessage{
Context: outboundContextFromInbound(
ts.opts.Dispatch.InboundContext,
ts.opts.Dispatch.Channel(),
ts.opts.Dispatch.ChatID(),
ts.opts.Dispatch.ReplyToMessageID(),
),
AgentID: agentID,
SessionKey: sessionKey,
Scope: scope,
Content: finalContent,
ContextUsage: contextUsage,
}
msg := outboundMessageForTurnWithOptions(ts, finalContent, outboundTurnMessageOptions{
modelName: exec.llmModelName,
})
msg.ContextUsage = contextUsage
markFinalOutbound(&msg)
_ = al.bus.PublishOutbound(turnCtx, msg)
}
@@ -112,6 +100,7 @@ func (p *Pipeline) Finalize(
ts.setPhase(TurnPhaseCompleted)
return turnResult{
finalContent: finalContent,
modelName: exec.llmModelName,
status: turnStatus,
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
}, nil
+13 -1
View File
@@ -200,6 +200,16 @@ func (p *Pipeline) CallLLM(
map[string]any{"agent_id": ts.agent.ID, "iteration": iteration},
)
}
for _, candidate := range exec.activeCandidates {
if candidate.StableKey() != fbResult.IdentityKey {
continue
}
exec.llmModelName = resolvedCandidateModelName(
[]providers.FallbackCandidate{candidate},
exec.llmModelName,
)
break
}
return fbResult.Response, nil
}
return exec.activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, exec.llmModel, exec.llmOpts)
@@ -477,7 +487,7 @@ func (p *Pipeline) CallLLM(
// Publish pico thoughts before the turn context is canceled at return time.
// The async variant can race with turn teardown and intermittently drop the
// thought message in CI even though the LLM produced reasoning content.
al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey)
al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey, exec.llmModelName)
}
} else {
go al.handleReasoning(
@@ -564,6 +574,7 @@ func (p *Pipeline) CallLLM(
assistantMsg := providers.Message{
Role: "assistant",
Content: exec.response.Content,
ModelName: exec.llmModelName,
ReasoningContent: reasoningContent,
}
for _, tc := range exec.normalizedToolCalls {
@@ -607,6 +618,7 @@ func (p *Pipeline) CallLLM(
al.publishPicoToolCallInterim(
turnCtx,
ts,
exec.llmModelName,
reasoningContent,
exec.response.Content,
assistantMsg.ToolCalls,
+6
View File
@@ -89,6 +89,11 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution
if usedLight && ts.agent.LightProvider != nil {
activeProvider = ts.agent.LightProvider
}
activeModelName := strings.TrimSpace(ts.agent.Model)
if usedLight {
activeModelName = strings.TrimSpace(sideQuestionModelName(ts.agent, true))
}
activeModelName = resolvedCandidateModelName(activeCandidates, activeModelName)
exec := newTurnExecution(
ts.agent,
@@ -106,6 +111,7 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution
activeModel,
p.Cfg.Agents.Defaults.Provider,
)
exec.llmModelName = activeModelName
exec.activeProvider = activeProvider
exec.usedLight = usedLight
+14 -3
View File
@@ -50,9 +50,10 @@ func (p *Pipeline) tryConfiguredStreamingLLM(
}
publisher := &streamingChunkPublisher{
streamer: streamer,
channel: ts.channel,
chatID: ts.chatID,
streamer: streamer,
channel: ts.channel,
chatID: ts.chatID,
modelName: exec.llmModelName,
}
logger.DebugCF("agent", "configured streaming enabled", map[string]any{
@@ -371,6 +372,7 @@ type streamingChunkPublisher struct {
streamer bus.Streamer
channel string
chatID string
modelName string
published bool
reasoningPublished bool
err error
@@ -380,6 +382,9 @@ func (p *streamingChunkPublisher) Update(ctx context.Context, accumulated string
if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" {
return
}
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
setter.SetModelName(p.modelName)
}
if err := p.streamer.Update(ctx, accumulated); err != nil {
p.err = err
logger.WarnCF("agent", "stream update failed", map[string]any{
@@ -396,6 +401,9 @@ func (p *streamingChunkPublisher) UpdateReasoning(ctx context.Context, accumulat
if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" {
return
}
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
setter.SetModelName(p.modelName)
}
reasoningStreamer, ok := p.streamer.(bus.ReasoningStreamer)
if !ok {
return
@@ -434,6 +442,9 @@ func (p *streamingChunkPublisher) Finalize(ctx context.Context, content string,
if strings.TrimSpace(content) == "" && !p.published {
return nil
}
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
setter.SetModelName(p.modelName)
}
var err error
if streamer, ok := p.streamer.(bus.ContextUsageStreamer); ok {
err = streamer.FinalizeWithContext(ctx, content, contextUsage)
+3
View File
@@ -570,6 +570,9 @@ func TestConfiguredStreamingFinalFlushFailureBeforeVisibleOutputPublishesFallbac
if outbound.Content != "stream response" {
t.Fatalf("fallback outbound content = %q, want stream response", outbound.Content)
}
if got := outbound.Context.Raw["model_name"]; got != "test-model" {
t.Fatalf("fallback outbound model_name = %q, want %q", got, "test-model")
}
case <-time.After(time.Second):
t.Fatal("expected fallback outbound after invisible final stream flush failure")
}
+180
View File
@@ -10,6 +10,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
@@ -37,6 +38,39 @@ func (p *simpleConvProvider) GetDefaultModel() string {
return "simple-model"
}
type sequenceProvider struct {
responses []*providers.LLMResponse
errors []error
callCount int
mu sync.Mutex
}
func (p *sequenceProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
defer p.mu.Unlock()
idx := p.callCount
p.callCount++
if idx < len(p.errors) && p.errors[idx] != nil {
return nil, p.errors[idx]
}
if idx < len(p.responses) && p.responses[idx] != nil {
return p.responses[idx], nil
}
return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil
}
func (p *sequenceProvider) GetDefaultModel() string {
return "sequence-model"
}
type nativeSearchCaptureProvider struct {
lastOpts map[string]any
}
@@ -271,6 +305,152 @@ func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
}
}
func TestPipeline_SetupTurn_ModelNameDoesNotUseFallbackAliasBeforeFallback(t *testing.T) {
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
defer cleanup()
agent.Model = "primary-model"
agent.Candidates = []providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4"},
{Provider: "anthropic", Model: "claude-sonnet", IdentityKey: "model_name:fallback-model"},
}
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
if exec.llmModelName != "primary-model" {
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "primary-model")
}
}
func TestPipeline_CallLLM_UsesSuccessfulFallbackIdentityAlias(t *testing.T) {
provider := &sequenceProvider{
errors: []error{
errors.New("status: 429 - rate limit exceeded"),
nil,
},
responses: []*providers.LLMResponse{
nil,
{Content: "fallback answer", FinishReason: "stop"},
},
}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
agent.Model = "primary-model"
agent.Candidates = []providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:secondary"},
}
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err != nil {
t.Fatalf("CallLLM failed: %v", err)
}
if ctrl != ControlBreak {
t.Fatalf("expected ControlBreak, got %v", ctrl)
}
if exec.llmModelName != "secondary" {
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "secondary")
}
}
func TestPipeline_CallLLM_UsesSuccessfulFallbackDisplayNameWithoutAlias(t *testing.T) {
provider := &sequenceProvider{
errors: []error{
errors.New("status: 429 - rate limit exceeded"),
nil,
},
responses: []*providers.LLMResponse{
nil,
{Content: "fallback answer", FinishReason: "stop"},
},
}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
agent.Model = "primary-model"
agent.Candidates = []providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
{Provider: "anthropic", Model: "claude-sonnet", DisplayName: "anthropic/claude-sonnet"},
}
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err != nil {
t.Fatalf("CallLLM failed: %v", err)
}
if ctrl != ControlBreak {
t.Fatalf("expected ControlBreak, got %v", ctrl)
}
if exec.llmModelName != "anthropic/claude-sonnet" {
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "anthropic/claude-sonnet")
}
}
func TestPipeline_SetupTurn_UsesLightCandidateDisplayName(t *testing.T) {
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
defer cleanup()
agent.Model = "primary-model"
agent.Candidates = []providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
}
agent.LightCandidates = []providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:light-model", DisplayName: "light-model"},
}
agent.Router = routing.New(routing.RouterConfig{LightModel: "light-model", Threshold: 1})
pipeline := NewPipeline(al)
opts := makeTestProcessOpts("test-session")
opts.UserMessage = ""
ts := newTurnState(agent, opts, turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
if !exec.usedLight {
t.Fatal("expected light routing to be used")
}
if exec.llmModelName != "light-model" {
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "light-model")
}
}
func TestRunTurn_FinalizeSaveErrorEmitsErrorTurnEnd(t *testing.T) {
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
defer cleanup()
+2
View File
@@ -84,6 +84,7 @@ const (
type turnResult struct {
finalContent string
modelName string
status TurnEndStatus
followUps []bus.InboundMessage
}
@@ -140,6 +141,7 @@ type turnExecution struct {
callMessages []providers.Message
providerToolDefs []providers.ToolDefinition
llmModel string
llmModelName string
llmOpts map[string]any
gracefulTerminal bool
useNativeSearch bool
+11
View File
@@ -20,6 +20,17 @@ type MessageEditor interface {
EditMessage(ctx context.Context, chatID string, messageID string, content string) error
}
// MessageEditorWithPayload extends MessageEditor for channels that can update
// structured message metadata in addition to plain text content.
type MessageEditorWithPayload interface {
EditMessageWithPayload(
ctx context.Context,
chatID string,
messageID string,
payload map[string]any,
) error
}
// MessageDeleter — channels that can delete a message by ID.
type MessageDeleter interface {
DeleteMessage(ctx context.Context, chatID string, messageID string) error
+63 -2
View File
@@ -191,6 +191,19 @@ func outboundMessageBypassesPlaceholderEdit(msg bus.OutboundMessage) bool {
return strings.EqualFold(kind, "thought") || strings.EqualFold(kind, "tool_calls")
}
func outboundMessageEditPayload(msg bus.OutboundMessage, content string) map[string]any {
payload := map[string]any{
"content": content,
}
if len(msg.Context.Raw) == 0 {
return payload
}
if modelName := strings.TrimSpace(msg.Context.Raw["model_name"]); modelName != "" {
payload["model_name"] = modelName
}
return payload
}
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
return msg.Context.Channel
}
@@ -394,7 +407,16 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
if deleter, ok := ch.(MessageDeleter); ok {
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
} else if editor, ok := ch.(MessageEditor); ok {
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
if payloadEditor, ok := ch.(MessageEditorWithPayload); ok {
_ = payloadEditor.EditMessageWithPayload(
ctx,
chatID,
entry.id,
outboundMessageEditPayload(msg, msg.Content),
)
} else {
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
}
}
}
}
@@ -446,7 +468,18 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
trackedContent = prepareToolFeedbackMessageContent(ch, msg.Content)
content = InitialAnimatedToolFeedbackContent(trackedContent)
}
if err := editor.EditMessage(ctx, chatID, entry.id, content); err == nil {
err := func() error {
if payloadEditor, ok := ch.(MessageEditorWithPayload); ok {
return payloadEditor.EditMessageWithPayload(
ctx,
chatID,
entry.id,
outboundMessageEditPayload(msg, content),
)
}
return editor.EditMessage(ctx, chatID, entry.id, content)
}()
if err == nil {
trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, &msg.Context)
if tracker, ok := ch.(toolFeedbackMessageTracker); ok && isToolFeedback {
tracker.RecordToolFeedbackMessage(trackedChatID, entry.id, trackedContent)
@@ -643,6 +676,18 @@ func reasoningStreamerFrom(streamer bus.Streamer) bus.ReasoningStreamer {
return nil
}
type modelNameStreamer interface {
SetModelName(modelName string)
}
func setStreamerModelName(streamer any, modelName string) {
setter, ok := streamer.(modelNameStreamer)
if !ok {
return
}
setter.SetModelName(modelName)
}
// splitMarkerStreamer turns accumulated streaming text containing
// MessageSplitMarker into separate channel stream messages.
type splitMarkerStreamer struct {
@@ -654,6 +699,7 @@ type splitMarkerStreamer struct {
finalized bool
onFinalize func(context.Context, string)
clearMarker func()
modelName string
}
func (s *splitMarkerStreamer) Update(ctx context.Context, content string) error {
@@ -682,6 +728,7 @@ func (s *splitMarkerStreamer) UpdateReasoning(ctx context.Context, content strin
if s.reasoning == nil {
return nil
}
setStreamerModelName(s.reasoning, s.modelName)
return s.reasoning.UpdateReasoning(ctx, content)
}
@@ -691,9 +738,18 @@ func (s *splitMarkerStreamer) FinalizeReasoning(ctx context.Context, content str
if s.reasoning == nil {
return nil
}
setStreamerModelName(s.reasoning, s.modelName)
return s.reasoning.FinalizeReasoning(ctx, content)
}
func (s *splitMarkerStreamer) SetModelName(modelName string) {
s.mu.Lock()
defer s.mu.Unlock()
s.modelName = strings.TrimSpace(modelName)
setStreamerModelName(s.current, s.modelName)
setStreamerModelName(s.reasoning, s.modelName)
}
func (s *splitMarkerStreamer) Cancel(ctx context.Context) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -772,6 +828,7 @@ func (s *splitMarkerStreamer) ensureCurrentLocked(ctx context.Context) error {
return err
}
s.current = streamer
setStreamerModelName(s.current, s.modelName)
return nil
}
@@ -856,6 +913,10 @@ func (s *finalizeHookStreamer) FinalizeReasoning(ctx context.Context, content st
return nil
}
func (s *finalizeHookStreamer) SetModelName(modelName string) {
setStreamerModelName(s.Streamer, strings.TrimSpace(modelName))
}
func (s *finalizeHookStreamer) runFinalizeHook(ctx context.Context, content string) {
if s.onFinalize != nil {
s.onFinalize(ctx, content)
+102
View File
@@ -142,11 +142,21 @@ func (m *mockReasoningStreamer) FinalizeReasoning(_ context.Context, content str
return nil
}
type modelTrackingReasoningStreamer struct {
mockReasoningStreamer
modelNames []string
}
func (m *modelTrackingReasoningStreamer) SetModelName(modelName string) {
m.modelNames = append(m.modelNames, strings.TrimSpace(modelName))
}
type recordingStreamSegment struct {
updates []string
finals []string
finalUsage *bus.ContextUsage
canceledCount int
modelNames []string
}
func (s *recordingStreamSegment) Update(_ context.Context, content string) error {
@@ -168,6 +178,10 @@ func (s *recordingStreamSegment) Cancel(context.Context) {
s.canceledCount++
}
func (s *recordingStreamSegment) SetModelName(modelName string) {
s.modelNames = append(s.modelNames, strings.TrimSpace(modelName))
}
type mockStreamingChannel struct {
mockMessageEditor
streamer Streamer
@@ -2068,6 +2082,42 @@ func TestGetStreamer_PreservesReasoningStreamer(t *testing.T) {
}
}
func TestGetStreamer_PreservesModelNameSetter(t *testing.T) {
m := newTestManager()
inner := &modelTrackingReasoningStreamer{}
ch := &mockStreamingChannel{
streamer: inner,
}
m.channels["test"] = ch
streamer, ok := m.GetStreamer(context.Background(), "test", "123", "")
if !ok {
t.Fatal("expected streamer to be available")
}
setter, ok := streamer.(interface{ SetModelName(modelName string) })
if !ok {
t.Fatal("manager-wrapped streamer should preserve SetModelName")
}
setter.SetModelName("gpt-5.4")
if err := streamer.Update(context.Background(), "hello"); err != nil {
t.Fatalf("Update() error = %v", err)
}
reasoningStreamer, ok := streamer.(bus.ReasoningStreamer)
if !ok {
t.Fatal("manager-wrapped streamer should preserve ReasoningStreamer")
}
setter.SetModelName("gpt-5.4")
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil {
t.Fatalf("UpdateReasoning() error = %v", err)
}
if len(inner.modelNames) != 2 {
t.Fatalf("model name calls = %v, want 2 forwarded calls", inner.modelNames)
}
if inner.modelNames[0] != "gpt-5.4" || inner.modelNames[1] != "gpt-5.4" {
t.Fatalf("model name calls = %v, want both forwarded as gpt-5.4", inner.modelNames)
}
}
func TestGetStreamer_SplitOnMarkerStreamsSeparateSegments(t *testing.T) {
m := newTestManager()
m.config = &config.Config{
@@ -2188,6 +2238,58 @@ func TestGetStreamer_SplitOnMarkerKeepsReasoningOnInitialStreamer(t *testing.T)
}
}
func TestGetStreamer_SplitOnMarkerPreservesModelNameSetter(t *testing.T) {
m := newTestManager()
m.config = &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
SplitOnMarker: true,
},
},
}
initial := &modelTrackingReasoningStreamer{}
next := &recordingStreamSegment{}
callCount := 0
ch := &mockStreamingChannel{
beginStreamFn: func(context.Context, string) (Streamer, error) {
callCount++
if callCount == 1 {
return initial, nil
}
return next, nil
},
}
m.channels["test"] = ch
streamer, ok := m.GetStreamer(context.Background(), "test", "123", "")
if !ok {
t.Fatal("expected streamer to be available")
}
setter, ok := streamer.(interface{ SetModelName(modelName string) })
if !ok {
t.Fatal("split streamer should preserve SetModelName")
}
setter.SetModelName("gpt-5.4-mini")
if err := streamer.Update(context.Background(), "hello<|[SPLIT]|>world"); err != nil {
t.Fatalf("Update() error = %v", err)
}
reasoningStreamer, ok := streamer.(bus.ReasoningStreamer)
if !ok {
t.Fatal("split streamer should preserve ReasoningStreamer")
}
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil {
t.Fatalf("UpdateReasoning() error = %v", err)
}
if len(initial.modelNames) == 0 || initial.modelNames[0] != "gpt-5.4-mini" {
t.Fatalf("initial model names = %v, want forwarded gpt-5.4-mini", initial.modelNames)
}
if len(next.modelNames) == 0 || next.modelNames[0] != "gpt-5.4-mini" {
t.Fatalf("next model names = %v, want forwarded gpt-5.4-mini", next.modelNames)
}
}
func TestGetStreamer_FinalizeSeparateMessagesClearsTrackedToolFeedback(t *testing.T) {
m := newTestManager()
m.config = &config.Config{
+91 -9
View File
@@ -325,6 +325,9 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
PayloadKeyContent: content,
"message_id": msgID,
}
if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" {
payload[PayloadKeyModelName] = modelName
}
switch {
case isThought:
payload[PayloadKeyKind] = MessageKindThought
@@ -359,6 +362,15 @@ func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID
return c.editMessage(ctx, chatID, messageID, content, nil)
}
func (c *PicoChannel) EditMessageWithPayload(
ctx context.Context,
chatID string,
messageID string,
payload map[string]any,
) error {
return c.editMessagePayload(ctx, chatID, messageID, payload, nil)
}
// DeleteMessage implements channels.MessageDeleter.
func (c *PicoChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error {
outMsg := newMessage(TypeMessageDelete, map[string]any{
@@ -419,14 +431,23 @@ func (c *PicoChannel) finalizeTrackedToolFeedbackMessage(
ctx context.Context,
chatID string,
content string,
editFn func(context.Context, string, string, string, *bus.ContextUsage) error,
editFn func(context.Context, string, string, map[string]any, *bus.ContextUsage) error,
payload map[string]any,
contextUsage *bus.ContextUsage,
) ([]string, bool) {
msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID)
if !ok || editFn == nil {
return nil, false
}
if err := editFn(ctx, chatID, msgID, content, contextUsage); err != nil {
if payload == nil {
payload = map[string]any{
PayloadKeyContent: content,
}
}
if _, ok := payload[PayloadKeyContent]; !ok {
payload[PayloadKeyContent] = content
}
if err := editFn(ctx, chatID, msgID, payload, contextUsage); err != nil {
c.RecordToolFeedbackMessage(chatID, msgID, baseContent)
return nil, false
}
@@ -437,7 +458,20 @@ func (c *PicoChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.O
if !outboundMessageFinalizesTrackedToolFeedback(msg) {
return nil, false
}
return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.editMessage, msg.ContextUsage)
payload := map[string]any{
PayloadKeyContent: msg.Content,
}
if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" {
payload[PayloadKeyModelName] = modelName
}
return c.finalizeTrackedToolFeedbackMessage(
ctx,
msg.ChatID,
msg.Content,
c.editMessagePayload,
payload,
msg.ContextUsage,
)
}
// StartTyping implements channels.TypingCapable.
@@ -496,6 +530,7 @@ func (c *PicoChannel) BeginStream(ctx context.Context, chatID string) (channels.
type picoStreamer struct {
channel *PicoChannel
chatID string
modelName string
messageID string
reasoningID string
throttleInterval time.Duration
@@ -509,6 +544,15 @@ type picoStreamer struct {
mu sync.Mutex
}
func (s *picoStreamer) SetModelName(modelName string) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.modelName = strings.TrimSpace(modelName)
}
func (s *picoStreamer) Update(ctx context.Context, content string) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -613,13 +657,23 @@ func (s *picoStreamer) sendLocked(ctx context.Context, content string, contextUs
PayloadKeyContent: content,
"message_id": s.messageID,
}
if s.modelName != "" {
payload[PayloadKeyModelName] = s.modelName
}
setContextUsagePayload(payload, contextUsage)
outMsg := newMessage(TypeMessageCreate, payload)
if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil {
return err
}
} else if content != s.lastContent || contextUsage != nil {
if err := s.channel.editMessage(ctx, s.chatID, s.messageID, content, contextUsage); err != nil {
payload := map[string]any{
PayloadKeyContent: content,
"message_id": s.messageID,
}
if s.modelName != "" {
payload[PayloadKeyModelName] = s.modelName
}
if err := s.channel.editMessagePayload(ctx, s.chatID, s.messageID, payload, contextUsage); err != nil {
return err
}
}
@@ -642,6 +696,9 @@ func (s *picoStreamer) sendReasoningLocked(ctx context.Context, content string)
PayloadKeyKind: MessageKindThought,
PayloadKeyThought: true,
}
if s.modelName != "" {
payload[PayloadKeyModelName] = s.modelName
}
outMsg := newMessage(TypeMessageCreate, payload)
if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil {
return err
@@ -653,6 +710,9 @@ func (s *picoStreamer) sendReasoningLocked(ctx context.Context, content string)
PayloadKeyKind: MessageKindThought,
PayloadKeyThought: true,
}
if s.modelName != "" {
payload[PayloadKeyModelName] = s.modelName
}
outMsg := newMessage(TypeMessageUpdate, payload)
if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil {
return err
@@ -744,6 +804,9 @@ func (c *PicoChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessag
"attachments": attachments,
"message_id": msgID,
})
if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" {
outMsg.Payload[PayloadKeyModelName] = modelName
}
if err := c.broadcastToSession(msg.ChatID, outMsg); err != nil {
return nil, err
@@ -1358,11 +1421,30 @@ func (c *PicoChannel) editMessage(
content string,
contextUsage *bus.ContextUsage,
) error {
payload := map[string]any{
"message_id": messageID,
"content": content,
return c.editMessagePayload(ctx, chatID, messageID, map[string]any{
PayloadKeyContent: content,
}, contextUsage)
}
func (c *PicoChannel) editMessagePayload(
ctx context.Context,
chatID string,
messageID string,
payload map[string]any,
contextUsage *bus.ContextUsage,
) error {
if payload == nil {
payload = map[string]any{}
}
setContextUsagePayload(payload, contextUsage)
outMsg := newMessage(TypeMessageUpdate, payload)
normalized := make(map[string]any, len(payload)+1)
for key, value := range payload {
normalized[key] = value
}
if _, ok := normalized[PayloadKeyContent]; !ok {
normalized[PayloadKeyContent] = ""
}
normalized["message_id"] = messageID
setContextUsagePayload(normalized, contextUsage)
outMsg := newMessage(TypeMessageUpdate, normalized)
return c.broadcastToSession(chatID, outMsg)
}
+90 -4
View File
@@ -46,12 +46,15 @@ func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T
context.Background(),
"pico:chat-1",
"final reply",
func(_ context.Context, chatID, messageID, content string, contextUsage *bus.ContextUsage) error {
func(_ context.Context, chatID, messageID string, payload map[string]any, contextUsage *bus.ContextUsage) error {
if _, ok := ch.currentToolFeedbackMessage(chatID); ok {
t.Fatal("expected tracked tool feedback to be stopped before edit")
}
if chatID != "pico:chat-1" || messageID != "msg-1" || content != "final reply" {
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
if chatID != "pico:chat-1" || messageID != "msg-1" {
t.Fatalf("unexpected edit args: %s %s", chatID, messageID)
}
if got := payload[PayloadKeyContent]; got != "final reply" {
t.Fatalf("unexpected content payload: %#v", got)
}
if contextUsage != nil {
t.Fatalf("unexpected context usage: %+v", contextUsage)
@@ -59,6 +62,7 @@ func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T
return nil
},
nil,
nil,
)
if !handled {
t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message")
@@ -115,7 +119,8 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
Channel: "pico",
ChatID: "pico:sess-1",
Raw: map[string]string{
"message_kind": MessageKindThought,
"message_kind": MessageKindThought,
PayloadKeyModelName: "gpt-5.4-mini",
},
},
}); err != nil {
@@ -134,6 +139,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
if got := payload[PayloadKeyKind]; got != MessageKindThought {
t.Fatalf("thought kind = %#v, want %q", got, MessageKindThought)
}
if got := payload[PayloadKeyModelName]; got != "gpt-5.4-mini" {
t.Fatalf("thought model_name = %#v, want %q", got, "gpt-5.4-mini")
}
if got := payload["message_id"]; got == "msg-progress" || got == nil || got == "" {
t.Fatalf("thought message_id = %#v, want new non-progress id", got)
}
@@ -151,6 +159,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
Context: bus.InboundContext{
Channel: "pico",
ChatID: "pico:sess-1",
Raw: map[string]string{
PayloadKeyModelName: "gpt-5.4",
},
},
ContextUsage: &bus.ContextUsage{
UsedTokens: 321,
@@ -174,6 +185,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
if got := payload[PayloadKeyContent]; got != "final reply" {
t.Fatalf("final content = %#v, want %q", got, "final reply")
}
if got := payload[PayloadKeyModelName]; got != "gpt-5.4" {
t.Fatalf("final model_name = %#v, want %q", got, "gpt-5.4")
}
rawUsage, ok := payload["context_usage"].(map[string]any)
if !ok {
t.Fatalf("final context_usage = %#v, want map payload", payload["context_usage"])
@@ -193,6 +207,54 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
}
}
func TestSend_ToolCallsMessageIncludesModelName(t *testing.T) {
ch := newTestPicoChannel(t)
if err := ch.Start(context.Background()); err != nil {
t.Fatalf("Start() error = %v", err)
}
defer ch.Stop(context.Background())
clientConn, received, cleanup := newTestPicoWebSocket(t)
defer cleanup()
ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"})
if _, err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "pico:sess-1",
Content: "",
Context: bus.InboundContext{
Channel: "pico",
ChatID: "pico:sess-1",
Raw: map[string]string{
"message_kind": MessageKindToolCalls,
PayloadKeyModelName: "gpt-5.4",
PayloadKeyToolCalls: `[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"README.md\"}"}}]`,
},
},
}); err != nil {
t.Fatalf("Send(tool_calls) error = %v", err)
}
select {
case msg := <-received:
if msg.Type != TypeMessageCreate {
t.Fatalf("tool_calls message type = %q, want %q", msg.Type, TypeMessageCreate)
}
payload := msg.Payload
if got := payload[PayloadKeyKind]; got != MessageKindToolCalls {
t.Fatalf("tool_calls kind = %#v, want %q", got, MessageKindToolCalls)
}
if got := payload[PayloadKeyModelName]; got != "gpt-5.4" {
t.Fatalf("tool_calls model_name = %#v, want %q", got, "gpt-5.4")
}
if _, ok := payload[PayloadKeyToolCalls].([]any); !ok {
t.Fatalf("tool_calls payload = %#v, want parsed array", payload[PayloadKeyToolCalls])
}
case <-time.After(time.Second):
t.Fatal("expected tool_calls message to be delivered")
}
}
func TestSendPlaceholder_EmitsNormalMessageWithoutKind(t *testing.T) {
ch := newTestPicoChannel(t)
ch.bc.Placeholder.Enabled = true
@@ -257,6 +319,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) {
if err != nil {
t.Fatalf("BeginStream() error = %v", err)
}
if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok {
setter.SetModelName("gpt-5.4")
}
if err := streamer.Update(context.Background(), "hello"); err != nil {
t.Fatalf("Update(first) error = %v", err)
}
@@ -271,6 +336,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) {
if got := first.Payload[PayloadKeyContent]; got != "hello" {
t.Fatalf("first content = %#v, want hello", got)
}
if got := first.Payload[PayloadKeyModelName]; got != "gpt-5.4" {
t.Fatalf("first model_name = %#v, want %q", got, "gpt-5.4")
}
rawStreamer := streamer.(*picoStreamer)
rawStreamer.mu.Lock()
@@ -290,6 +358,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) {
if got := second.Payload[PayloadKeyContent]; got != secondContent {
t.Fatalf("second content = %#v, want %q", got, secondContent)
}
if got := second.Payload[PayloadKeyModelName]; got != "gpt-5.4" {
t.Fatalf("second model_name = %#v, want %q", got, "gpt-5.4")
}
}
func TestBeginStream_DefaultStreamingShowsSmallIncrements(t *testing.T) {
@@ -355,6 +426,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) {
if !ok {
t.Fatal("pico stream should support reasoning updates")
}
if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok {
setter.SetModelName("gpt-5.4-mini")
}
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil {
t.Fatalf("UpdateReasoning(first) error = %v", err)
}
@@ -372,6 +446,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) {
if got := first.Payload[PayloadKeyContent]; got != "thinking" {
t.Fatalf("first content = %#v, want thinking", got)
}
if got := first.Payload[PayloadKeyModelName]; got != "gpt-5.4-mini" {
t.Fatalf("first model_name = %#v, want %q", got, "gpt-5.4-mini")
}
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking more"); err != nil {
t.Fatalf("UpdateReasoning(second) error = %v", err)
@@ -389,6 +466,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) {
if got := second.Payload[PayloadKeyContent]; got != "thinking more" {
t.Fatalf("second content = %#v, want thinking more", got)
}
if got := second.Payload[PayloadKeyModelName]; got != "gpt-5.4-mini" {
t.Fatalf("second model_name = %#v, want %q", got, "gpt-5.4-mini")
}
}
func TestBeginStream_ThrottlesIntermediateUpdatesAndFinalFlushes(t *testing.T) {
@@ -473,6 +553,9 @@ func TestBeginStream_FinalizeIncludesContextUsage(t *testing.T) {
if err != nil {
t.Fatalf("BeginStream() error = %v", err)
}
if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok {
setter.SetModelName("gpt-5.4")
}
if err := streamer.Update(context.Background(), "partial"); err != nil {
t.Fatalf("Update() error = %v", err)
}
@@ -501,6 +584,9 @@ func TestBeginStream_FinalizeIncludesContextUsage(t *testing.T) {
if got := final.Payload["message_id"]; got != msgID {
t.Fatalf("final message_id = %#v, want %q", got, msgID)
}
if got := final.Payload[PayloadKeyModelName]; got != "gpt-5.4" {
t.Fatalf("final model_name = %#v, want %q", got, "gpt-5.4")
}
rawUsage, ok := final.Payload["context_usage"].(map[string]any)
if !ok {
t.Fatalf("final context_usage = %#v, want map", final.Payload["context_usage"])
+1
View File
@@ -27,6 +27,7 @@ const (
PayloadKeyKind = "kind"
PayloadKeyPlaceholder = "placeholder"
PayloadKeyToolCalls = "tool_calls"
PayloadKeyModelName = "model_name"
MessageKindThought = "thought"
MessageKindToolCalls = "tool_calls"
+26
View File
@@ -130,6 +130,32 @@ func TestAddFullMessage_WithToolCalls(t *testing.T) {
}
}
func TestAddFullMessage_PreservesModelName(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
msg := providers.Message{
Role: "assistant",
Content: "done",
ModelName: "gpt-5.4-mini",
}
if err := store.AddFullMessage(ctx, "model-name", msg); err != nil {
t.Fatalf("AddFullMessage: %v", err)
}
history, err := store.GetHistory(ctx, "model-name")
if err != nil {
t.Fatalf("GetHistory: %v", err)
}
if len(history) != 1 {
t.Fatalf("expected 1, got %d", len(history))
}
if history[0].ModelName != "gpt-5.4-mini" {
t.Fatalf("ModelName = %q, want %q", history[0].ModelName, "gpt-5.4-mini")
}
}
func TestAddFullMessage_ToolCallID(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
+11 -6
View File
@@ -17,6 +17,7 @@ type FallbackChain struct {
type FallbackCandidate struct {
Provider string
Model string
DisplayName string // optional configured alias/raw model label for persistence/UI
RPM int // requests per minute; 0 means unrestricted
IdentityKey string // optional stable config identity for cooldown/rate limiting
}
@@ -32,10 +33,11 @@ func (c FallbackCandidate) StableKey() string {
// FallbackResult contains the successful response and metadata about all attempts.
type FallbackResult struct {
Response *LLMResponse
Provider string
Model string
Attempts []FallbackAttempt
Response *LLMResponse
Provider string
Model string
IdentityKey string
Attempts []FallbackAttempt
}
// FallbackAttempt records one attempt in the fallback chain.
@@ -85,8 +87,9 @@ func ResolveCandidatesWithLookup(
}
seen[key] = true
candidates = append(candidates, FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
Provider: ref.Provider,
Model: ref.Model,
DisplayName: candidateRaw,
})
}
@@ -187,6 +190,7 @@ func (fc *FallbackChain) Execute(
result.Response = resp
result.Provider = candidate.Provider
result.Model = candidate.Model
result.IdentityKey = candidate.StableKey()
return result, nil
}
@@ -305,6 +309,7 @@ func (fc *FallbackChain) ExecuteImage(
result.Response = resp
result.Provider = candidate.Provider
result.Model = candidate.Model
result.IdentityKey = candidate.StableKey()
return result, nil
}
+1
View File
@@ -86,6 +86,7 @@ type Attachment struct {
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ModelName string `json:"model_name,omitempty"`
Media []string `json:"media,omitempty"`
Attachments []Attachment `json:"attachments,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
+19
View File
@@ -46,6 +46,7 @@ func runSchema(db *sql.DB) error {
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
role TEXT NOT NULL,
content TEXT NOT NULL DEFAULT '',
model_name TEXT NOT NULL DEFAULT '',
reasoning_content TEXT NOT NULL DEFAULT '',
token_count INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
@@ -162,6 +163,9 @@ func runSchema(db *sql.DB) error {
if err := ensureMessagesReasoningContentColumn(db); err != nil {
return err
}
if err := ensureMessagesModelNameColumn(db); err != nil {
return err
}
return nil
}
@@ -180,6 +184,21 @@ func ensureMessagesReasoningContentColumn(db *sql.DB) error {
return nil
}
func ensureMessagesModelNameColumn(db *sql.DB) error {
hasColumn, err := tableHasColumn(db, "messages", "model_name")
if err != nil {
return fmt.Errorf("check messages.model_name: %w", err)
}
if hasColumn {
return nil
}
if _, err := db.Exec(`ALTER TABLE messages ADD COLUMN model_name TEXT NOT NULL DEFAULT ''`); err != nil {
return fmt.Errorf("add messages.model_name: %w", err)
}
return nil
}
func tableHasColumn(db *sql.DB, tableName, columnName string) (bool, error) {
rows, err := db.Query(fmt.Sprintf(`PRAGMA table_info(%s)`, tableName))
if err != nil {
+31
View File
@@ -138,6 +138,37 @@ func TestRunSchemaAddsMessagesReasoningContentColumn(t *testing.T) {
}
}
func TestRunSchemaAddsMessagesModelNameColumn(t *testing.T) {
db := openTestDB(t)
_, err := db.Exec(`CREATE TABLE messages (
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL DEFAULT '',
reasoning_content TEXT NOT NULL DEFAULT '',
token_count INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`)
if err != nil {
t.Fatalf("create legacy messages table: %v", err)
}
err = runSchema(db)
if err != nil {
t.Fatalf("runSchema: %v", err)
}
var count int
err = db.QueryRow(`SELECT count(*) FROM pragma_table_info('messages') WHERE name = 'model_name'`).Scan(&count)
if err != nil {
t.Fatalf("query pragma_table_info: %v", err)
}
if count != 1 {
t.Fatalf("model_name column count = %d, want 1", count)
}
}
func TestMigrationConversationUnique(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
+97 -20
View File
@@ -258,6 +258,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa
conv.ConversationID,
msg.Role,
msg.Parts,
msg.ModelName,
msg.ReasoningContent,
msg.TokenCount,
)
@@ -267,6 +268,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa
conv.ConversationID,
msg.Role,
msg.Content,
msg.ModelName,
msg.ReasoningContent,
msg.TokenCount,
)
@@ -431,6 +433,31 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
return fmt.Errorf("bootstrap: get messages: %w", err)
}
// Migration repair path: old SeaHorse rows may be missing reasoning_content
// even though the canonical JSONL history already has it. Backfill those
// rows in place so we do not treat this as edited history and leave stale
// summaries/context behind after a partial raw-message rebuild.
repairedReasoning, err := e.repairBootstrapReasoningContent(ctx, dbMsgs, messages)
if err != nil {
return fmt.Errorf("bootstrap: repair reasoning_content: %w", err)
}
repairedModelName, err := e.repairBootstrapModelName(ctx, dbMsgs, messages)
if err != nil {
return fmt.Errorf("bootstrap: repair model_name: %w", err)
}
if (repairedReasoning || repairedModelName) && len(dbMsgs) == len(messages) {
matched := true
for i := range messages {
if !messageMatches(dbMsgs[i], messages[i]) {
matched = false
break
}
}
if matched {
return nil
}
}
// Fast path: DB has same count and exact match → no-op
if len(dbMsgs) == len(messages) {
matched := true
@@ -445,16 +472,6 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
}
}
// Migration repair path: old SeaHorse rows may be missing reasoning_content
// even though the canonical JSONL history already has it. Backfill those
// rows in place so we do not treat this as edited history and leave stale
// summaries/context behind after a partial raw-message rebuild.
if repaired, err := e.repairBootstrapReasoningContent(ctx, dbMsgs, messages); err != nil {
return fmt.Errorf("bootstrap: repair reasoning_content: %w", err)
} else if repaired && len(dbMsgs) == len(messages) {
return nil
}
// Find longest matching prefix from the start
anchor := -1
compareLen := min(len(dbMsgs), len(messages))
@@ -465,14 +482,16 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
} else {
// Mismatch detected - log details and rebuild
logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
"conv_id": conv.ConversationID,
"index": i,
"db_role": dbMsgs[i].Role,
"db_content": truncate(dbMsgs[i].Content, 50),
"db_parts": len(dbMsgs[i].Parts),
"msg_role": messages[i].Role,
"msg_content": truncate(messages[i].Content, 50),
"msg_parts": len(messages[i].Parts),
"conv_id": conv.ConversationID,
"index": i,
"db_role": dbMsgs[i].Role,
"db_content": truncate(dbMsgs[i].Content, 50),
"db_parts": len(dbMsgs[i].Parts),
"db_model_name": dbMsgs[i].ModelName,
"msg_role": messages[i].Role,
"msg_content": truncate(messages[i].Content, 50),
"msg_parts": len(messages[i].Parts),
"msg_model_name": messages[i].ModelName,
})
break
}
@@ -559,7 +578,7 @@ func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, me
}
for i := range overlap {
if !messageMatchesIgnoringReasoning(dbMsgs[i], messages[i]) {
if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) {
return false, nil
}
if dbMsgs[i].ReasoningContent == messages[i].ReasoningContent {
@@ -596,6 +615,57 @@ func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, me
return true, nil
}
func (e *Engine) repairBootstrapModelName(ctx context.Context, dbMsgs, messages []Message) (bool, error) {
if len(dbMsgs) == 0 || len(messages) == 0 {
return false, nil
}
overlap := min(len(messages), len(dbMsgs))
var updates []struct {
index int
messageID int64
modelName string
}
for i := range overlap {
if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) {
return false, nil
}
if dbMsgs[i].ModelName == messages[i].ModelName {
continue
}
if messages[i].ModelName == "" {
return false, nil
}
updates = append(updates, struct {
index int
messageID int64
modelName string
}{
index: i,
messageID: dbMsgs[i].ID,
modelName: messages[i].ModelName,
})
}
if len(updates) == 0 {
return false, nil
}
for _, update := range updates {
if err := e.store.UpdateMessageModelName(ctx, update.messageID, update.modelName); err != nil {
return false, err
}
dbMsgs[update.index].ModelName = update.modelName
}
logger.InfoCF("seahorse", "bootstrap: repaired missing model_name", map[string]any{
"messages": len(updates),
})
return true, nil
}
// truncate shortens a string for logging.
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
@@ -610,13 +680,20 @@ func truncate(s string, maxLen int) string {
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
// because structured messages are matched by their parts payload.
func messageMatches(a, b Message) bool {
if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent {
if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent || a.ModelName != b.ModelName {
return false
}
return messageMatchesIgnoringReasoning(a, b)
}
func messageMatchesIgnoringReasoning(a, b Message) bool {
if a.ModelName != b.ModelName {
return false
}
return messageMatchesIgnoringReasoningAndModelName(a, b)
}
func messageMatchesIgnoringReasoningAndModelName(a, b Message) bool {
if a.Role != b.Role {
return false
}
+178 -28
View File
@@ -25,6 +25,43 @@ func newTestEngine(t *testing.T) *Engine {
}
}
func prepareBootstrapRepairConversation(
t *testing.T,
eng *Engine,
ctx context.Context,
sessionKey string,
) (*Conversation, []Message) {
t.Helper()
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
t.Fatalf("GetOrCreateConversation: %v", err)
}
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
if err != nil {
t.Fatalf("AddMessage user: %v", err)
}
assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3)
if err != nil {
t.Fatalf("AddMessage assistant: %v", err)
}
if err := eng.store.AppendContextMessages(
ctx,
conv.ConversationID,
[]int64{userMsg.ID, assistantMsg.ID},
); err != nil {
t.Fatalf("AppendContextMessages: %v", err)
}
return conv, []Message{
{Role: "user", Content: "hello", TokenCount: 3},
{Role: "assistant", Content: "world", TokenCount: 3},
}
}
// --- compileSessionPattern ---
func TestCompileSessionPattern(t *testing.T) {
@@ -328,6 +365,7 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) {
{
Role: "assistant",
Content: "world",
ModelName: "gpt-5.4-mini",
ReasoningContent: "let me think this through",
TokenCount: 4,
},
@@ -353,6 +391,9 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) {
"let me think this through",
)
}
if stored[0].ModelName != "gpt-5.4-mini" {
t.Errorf("stored[0].ModelName = %q, want %q", stored[0].ModelName, "gpt-5.4-mini")
}
result, err := eng.Assemble(ctx, "agent:reasoning", AssembleInput{Budget: 1000})
if err != nil {
@@ -368,6 +409,140 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) {
"let me think this through",
)
}
if result.Messages[0].ModelName != "gpt-5.4-mini" {
t.Errorf("assembled model_name = %q, want %q", result.Messages[0].ModelName, "gpt-5.4-mini")
}
}
func TestBootstrapRepairsMissingModelName(t *testing.T) {
eng := newTestEngine(t)
ctx := context.Background()
sessionKey := "agent:repair-model-name"
conv, msgs := prepareBootstrapRepairConversation(t, eng, ctx, sessionKey)
msgs[1].ModelName = "gpt-5.4"
err := eng.Bootstrap(ctx, sessionKey, msgs)
if err != nil {
t.Fatalf("Bootstrap: %v", err)
}
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(stored) != 2 {
t.Fatalf("stored messages = %d, want 2", len(stored))
}
if stored[1].ModelName != "gpt-5.4" {
t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4")
}
}
func TestBootstrapRepairsReasoningContentAndModelNameTogether(t *testing.T) {
eng := newTestEngine(t)
ctx := context.Background()
sessionKey := "agent:repair-both-fields"
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
t.Fatalf("GetOrCreateConversation: %v", err)
}
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
if err != nil {
t.Fatalf("AddMessage user: %v", err)
}
assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3)
if err != nil {
t.Fatalf("AddMessage assistant: %v", err)
}
err = eng.store.AppendContextMessages(ctx, conv.ConversationID, []int64{userMsg.ID, assistantMsg.ID})
if err != nil {
t.Fatalf("AppendContextMessages: %v", err)
}
err = eng.Bootstrap(ctx, sessionKey, []Message{
{Role: "user", Content: "hello", TokenCount: 3},
{
Role: "assistant",
Content: "world",
ModelName: "gpt-5.4",
ReasoningContent: "let me think this through",
TokenCount: 3,
},
})
if err != nil {
t.Fatalf("Bootstrap: %v", err)
}
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(stored) != 2 {
t.Fatalf("stored messages = %d, want 2", len(stored))
}
if stored[1].ReasoningContent != "let me think this through" {
t.Fatalf("stored[1].ReasoningContent = %q, want %q", stored[1].ReasoningContent, "let me think this through")
}
if stored[1].ModelName != "gpt-5.4" {
t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4")
}
}
func TestBootstrapRepairsIncorrectNonEmptyModelName(t *testing.T) {
eng := newTestEngine(t)
ctx := context.Background()
sessionKey := "agent:repair-wrong-model-name"
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
t.Fatalf("GetOrCreateConversation: %v", err)
}
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
if err != nil {
t.Fatalf("AddMessage user: %v", err)
}
assistantMsg, err := eng.store.AddMessageWithReasoning(
ctx,
conv.ConversationID,
"assistant",
"world",
"wrong-model",
"",
3,
)
if err != nil {
t.Fatalf("AddMessageWithReasoning assistant: %v", err)
}
err = eng.store.AppendContextMessages(ctx, conv.ConversationID, []int64{userMsg.ID, assistantMsg.ID})
if err != nil {
t.Fatalf("AppendContextMessages: %v", err)
}
err = eng.Bootstrap(ctx, sessionKey, []Message{
{Role: "user", Content: "hello", TokenCount: 3},
{Role: "assistant", Content: "world", ModelName: "gpt-5.4", TokenCount: 3},
})
if err != nil {
t.Fatalf("Bootstrap: %v", err)
}
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
if err != nil {
t.Fatalf("GetMessages: %v", err)
}
if len(stored) != 2 {
t.Fatalf("stored messages = %d, want 2", len(stored))
}
if stored[1].ModelName != "gpt-5.4" {
t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4")
}
}
func TestEngineIngestWithPartsPreservesReasoningContent(t *testing.T) {
@@ -620,35 +795,10 @@ func TestBootstrapRepairsMissingReasoningContent(t *testing.T) {
eng := newTestEngine(t)
ctx := context.Background()
sessionKey := "agent:repair-reasoning"
conv, msgs := prepareBootstrapRepairConversation(t, eng, ctx, sessionKey)
msgs[1].ReasoningContent = "let me think this through"
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
t.Fatalf("GetOrCreateConversation: %v", err)
}
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
if err != nil {
t.Fatalf("AddMessage user: %v", err)
}
assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3)
if err != nil {
t.Fatalf("AddMessage assistant: %v", err)
}
err = eng.store.AppendContextMessages(
ctx,
conv.ConversationID,
[]int64{userMsg.ID, assistantMsg.ID},
)
if err != nil {
t.Fatalf("AppendContextMessages: %v", err)
}
err = eng.Bootstrap(ctx, sessionKey, []Message{
{Role: "user", Content: "hello", TokenCount: 3},
{Role: "assistant", Content: "world", ReasoningContent: "let me think this through", TokenCount: 3},
})
err := eng.Bootstrap(ctx, sessionKey, msgs)
if err != nil {
t.Fatalf("Bootstrap: %v", err)
}
+53 -14
View File
@@ -162,19 +162,25 @@ func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Tim
// AddMessage appends a message to a conversation.
func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) {
return s.AddMessageWithReasoning(ctx, convID, role, content, "", tokenCount)
return s.AddMessageWithReasoning(ctx, convID, role, content, "", "", tokenCount)
}
// AddMessageWithReasoning appends a message with reasoning content to a conversation.
func (s *Store) AddMessageWithReasoning(
ctx context.Context,
convID int64,
role, content, reasoningContent string,
role, content, modelName, reasoningContent string,
tokenCount int,
) (*Message, error) {
result, err := s.db.ExecContext(ctx,
"INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)",
convID, role, content, reasoningContent, tokenCount,
result, err := s.db.ExecContext(
ctx,
"INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)",
convID,
role,
content,
modelName,
reasoningContent,
tokenCount,
)
if err != nil {
return nil, fmt.Errorf("add message: %w", err)
@@ -185,6 +191,7 @@ func (s *Store) AddMessageWithReasoning(
ConversationID: convID,
Role: role,
Content: content,
ModelName: modelName,
ReasoningContent: reasoningContent,
TokenCount: tokenCount,
}, nil
@@ -224,7 +231,7 @@ func (s *Store) AddMessageWithParts(
parts []MessagePart,
tokenCount int,
) (*Message, error) {
return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", tokenCount)
return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", "", tokenCount)
}
// AddMessageWithPartsAndReasoning adds a message with structured parts and reasoning content.
@@ -233,6 +240,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
convID int64,
role string,
parts []MessagePart,
modelName string,
reasoningContent string,
tokenCount int,
) (*Message, error) {
@@ -245,9 +253,15 @@ func (s *Store) AddMessageWithPartsAndReasoning(
// Derive readable content from Parts for FTS5 indexing and summary formatting
readableContent := partsToReadableContent(parts)
result, err := tx.ExecContext(ctx,
"INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)",
convID, role, readableContent, reasoningContent, tokenCount,
result, err := tx.ExecContext(
ctx,
"INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)",
convID,
role,
readableContent,
modelName,
reasoningContent,
tokenCount,
)
if err != nil {
return nil, fmt.Errorf("add message: %w", err)
@@ -282,6 +296,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
ID: msgID,
ConversationID: convID,
Role: role,
ModelName: modelName,
ReasoningContent: reasoningContent,
TokenCount: tokenCount,
Parts: make([]MessagePart, len(parts)),
@@ -295,7 +310,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
// GetMessages retrieves messages for a conversation.
func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, beforeID int64) ([]Message, error) {
query := "SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE conversation_id = ?"
query := "SELECT message_id, conversation_id, role, content, model_name, reasoning_content, token_count, created_at FROM messages WHERE conversation_id = ?"
args := []any{convID}
if beforeID > 0 {
query += " AND message_id < ?"
@@ -322,6 +337,7 @@ func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, before
&msg.ConversationID,
&msg.Role,
&msg.Content,
&msg.ModelName,
&msg.ReasoningContent,
&msg.TokenCount,
&createdAt,
@@ -362,9 +378,9 @@ func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message,
var createdAt string
err := s.db.QueryRowContext(
ctx,
"SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE message_id = ?",
"SELECT message_id, conversation_id, role, content, model_name, reasoning_content, token_count, created_at FROM messages WHERE message_id = ?",
messageID,
).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.ReasoningContent, &msg.TokenCount, &createdAt)
).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.ModelName, &msg.ReasoningContent, &msg.TokenCount, &createdAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("message %d not found", messageID)
}
@@ -398,6 +414,27 @@ func (s *Store) UpdateMessageReasoningContent(ctx context.Context, messageID int
return nil
}
func (s *Store) UpdateMessageModelName(ctx context.Context, messageID int64, modelName string) error {
result, err := s.db.ExecContext(
ctx,
"UPDATE messages SET model_name = ? WHERE message_id = ?",
modelName,
messageID,
)
if err != nil {
return fmt.Errorf("update message model_name: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("update message model_name rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("message %d not found", messageID)
}
return nil
}
func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type
@@ -581,8 +618,9 @@ func (s *Store) LinkSummaryToMessages(ctx context.Context, summaryID string, mes
// GetSummarySourceMessages retrieves source messages for a summary.
func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) ([]Message, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT m.message_id, m.conversation_id, m.role, m.content, m.reasoning_content, m.token_count, m.created_at
rows, err := s.db.QueryContext(
ctx,
`SELECT m.message_id, m.conversation_id, m.role, m.content, m.model_name, m.reasoning_content, m.token_count, m.created_at
FROM summary_messages sm
JOIN messages m ON m.message_id = sm.message_id
WHERE sm.summary_id = ?
@@ -603,6 +641,7 @@ func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string)
&msg.ConversationID,
&msg.Role,
&msg.Content,
&msg.ModelName,
&msg.ReasoningContent,
&msg.TokenCount,
&createdAt,
+14
View File
@@ -210,6 +210,7 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
conv.ConversationID,
"assistant",
"hello world",
"gpt-5.4-mini",
"let me think",
5,
)
@@ -219,6 +220,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
if msg.ReasoningContent != "let me think" {
t.Fatalf("ReasoningContent = %q, want %q", msg.ReasoningContent, "let me think")
}
if msg.ModelName != "gpt-5.4-mini" {
t.Fatalf("ModelName = %q, want %q", msg.ModelName, "gpt-5.4-mini")
}
msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0)
if err != nil {
@@ -230,6 +234,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
if msgs[0].ReasoningContent != "let me think" {
t.Errorf("ReasoningContent = %q, want %q", msgs[0].ReasoningContent, "let me think")
}
if msgs[0].ModelName != "gpt-5.4-mini" {
t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4-mini")
}
found, err := s.GetMessageByID(ctx, msg.ID)
if err != nil {
@@ -238,6 +245,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
if found.ReasoningContent != "let me think" {
t.Errorf("GetMessageByID ReasoningContent = %q, want %q", found.ReasoningContent, "let me think")
}
if found.ModelName != "gpt-5.4-mini" {
t.Errorf("GetMessageByID ModelName = %q, want %q", found.ModelName, "gpt-5.4-mini")
}
}
func TestStoreAddMessageWithParts(t *testing.T) {
@@ -288,6 +298,7 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) {
conv.ConversationID,
"assistant",
parts,
"gpt-5.4",
"need to inspect the file first",
10,
)
@@ -309,6 +320,9 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) {
"need to inspect the file first",
)
}
if msgs[0].ModelName != "gpt-5.4" {
t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4")
}
}
func TestStoreGetMessageCount(t *testing.T) {
+2
View File
@@ -22,6 +22,7 @@ type Message struct {
ConversationID int64 `json:"conversationId"`
Role string `json:"role"`
Content string `json:"content"`
ModelName string `json:"modelName,omitempty"`
ReasoningContent string `json:"reasoningContent,omitempty"`
TokenCount int `json:"tokenCount"`
CreatedAt time.Time `json:"createdAt"`
@@ -135,6 +136,7 @@ func EstimateMessageTokens(msg Message) int {
pm := providers.Message{
Role: msg.Role,
Content: msg.Content,
ModelName: msg.ModelName,
ReasoningContent: msg.ReasoningContent,
}
+19
View File
@@ -66,6 +66,25 @@ func TestJSONLBackend_AddFullMessage(t *testing.T) {
}
}
func TestJSONLBackend_AddFullMessage_PreservesModelName(t *testing.T) {
b := newBackend(t)
msg := providers.Message{
Role: "assistant",
Content: "done",
ModelName: "gpt-5.4-mini",
}
b.AddFullMessage("s1", msg)
history := b.GetHistory("s1")
if len(history) != 1 {
t.Fatalf("got %d, want 1", len(history))
}
if history[0].ModelName != "gpt-5.4-mini" {
t.Fatalf("ModelName = %q, want %q", history[0].ModelName, "gpt-5.4-mini")
}
}
func TestJSONLBackend_Summary(t *testing.T) {
b := newBackend(t)
+15 -7
View File
@@ -50,6 +50,7 @@ type sessionChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Kind string `json:"kind,omitempty"`
ModelName string `json:"model_name,omitempty"`
Media []string `json:"media,omitempty"`
Attachments []sessionChatAttachment `json:"attachments,omitempty"`
ToolCalls []utils.VisibleToolCall `json:"tool_calls,omitempty"`
@@ -510,6 +511,7 @@ func sessionTranscriptMessages(
chatMsg := sessionChatMessage{
Role: "user",
Content: msg.Content,
ModelName: msg.ModelName,
Media: append([]string(nil), msg.Media...),
Attachments: attachments,
}
@@ -529,9 +531,10 @@ func sessionTranscriptMessages(
toolCallsMsg, hasToolCallsMsg := assistantToolCallsMessage(
msg.ToolCalls,
msg.ModelName,
toolFeedbackMaxArgsLength,
)
visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls)
visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls, msg.ModelName)
// Pico web chat can persist both visible `message` tool output and a
// later plain assistant reply in the same turn. Hide only the fixed
@@ -556,6 +559,7 @@ func sessionTranscriptMessages(
chatMsg := sessionChatMessage{
Role: "assistant",
Content: content,
ModelName: msg.ModelName,
Media: append([]string(nil), msg.Media...),
Attachments: attachments,
}
@@ -682,14 +686,16 @@ func assistantThoughtMessage(msg providers.Message) (sessionChatMessage, bool) {
return sessionChatMessage{}, false
}
return sessionChatMessage{
Role: "assistant",
Content: reasoning,
Kind: "thought",
Role: "assistant",
Content: reasoning,
Kind: "thought",
ModelName: msg.ModelName,
}, true
}
func assistantToolCallsMessage(
toolCalls []providers.ToolCall,
modelName string,
toolFeedbackMaxArgsLength int,
) (sessionChatMessage, bool) {
if len(toolCalls) == 0 {
@@ -707,6 +713,7 @@ func assistantToolCallsMessage(
return sessionChatMessage{
Role: "assistant",
Kind: "tool_calls",
ModelName: modelName,
ToolCalls: visibleToolCalls,
}, true
}
@@ -718,7 +725,7 @@ func visibleAssistantToolArgsPreview(
return utils.VisibleToolCallArgumentsPreview(tc, toolFeedbackMaxArgsLength)
}
func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage {
func visibleAssistantToolMessages(toolCalls []providers.ToolCall, modelName string) []sessionChatMessage {
if len(toolCalls) == 0 {
return nil
}
@@ -734,8 +741,9 @@ func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatM
continue
}
messages = append(messages, sessionChatMessage{
Role: "assistant",
Content: content,
Role: "assistant",
Content: content,
ModelName: modelName,
})
}
+16 -3
View File
@@ -564,7 +564,7 @@ func TestHandleGetSession_ReconstructsThoughtFromAssistantReasoningContent(t *te
sessionKey := picoSessionPrefix + "detail-reasoning-content"
for _, msg := range []providers.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "final visible answer", ReasoningContent: "internal chain of thought"},
{Role: "assistant", Content: "final visible answer", ModelName: "gpt-5.4", ReasoningContent: "internal chain of thought"},
} {
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
@@ -597,9 +597,15 @@ func TestHandleGetSession_ReconstructsThoughtFromAssistantReasoningContent(t *te
resp.Messages[1].Kind != "thought" {
t.Fatalf("thought message = %#v, want assistant thought/internal chain of thought", resp.Messages[1])
}
if resp.Messages[1].ModelName != "gpt-5.4" {
t.Fatalf("thought model_name = %q, want %q", resp.Messages[1].ModelName, "gpt-5.4")
}
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "final visible answer" {
t.Fatalf("final message = %#v, want assistant/final visible answer", resp.Messages[2])
}
if resp.Messages[2].ModelName != "gpt-5.4" {
t.Fatalf("final model_name = %q, want %q", resp.Messages[2].ModelName, "gpt-5.4")
}
}
func TestHandleGetSession_ReconstructsRefreshMatrixForThoughtAndToolSummary(t *testing.T) {
@@ -725,8 +731,9 @@ func TestHandleGetSession_ReconstructsVisibleMessageToolOutputWithoutDuplicateSu
for _, msg := range []providers.Message{
{Role: "user", Content: "test"},
{
Role: "assistant",
Content: "",
Role: "assistant",
Content: "",
ModelName: "gpt-5.4-mini",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
@@ -771,9 +778,15 @@ func TestHandleGetSession_ReconstructsVisibleMessageToolOutputWithoutDuplicateSu
t.Fatalf("first message = %#v, want user/test", resp.Messages[0])
}
assertVisibleToolCallMessage(t, resp.Messages[1], "message")
if resp.Messages[1].ModelName != "gpt-5.4-mini" {
t.Fatalf("tool_calls model_name = %q, want %q", resp.Messages[1].ModelName, "gpt-5.4-mini")
}
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "visible tool output" {
t.Fatalf("assistant message = %#v, want visible tool output", resp.Messages[2])
}
if resp.Messages[2].ModelName != "gpt-5.4-mini" {
t.Fatalf("visible tool output model_name = %q, want %q", resp.Messages[2].ModelName, "gpt-5.4-mini")
}
}
func TestHandleGetSession_PreservesFinalAssistantReplyAfterMessageToolOutput(t *testing.T) {
+1
View File
@@ -15,6 +15,7 @@ export interface SessionDetail {
role: "user" | "assistant"
content: string
kind?: "normal" | "thought" | "tool_calls"
model_name?: string
media?: string[]
attachments?: {
type?: "image" | "audio" | "video" | "file"
@@ -33,6 +33,7 @@ interface AssistantMessageProps {
content: string
attachments?: ChatAttachment[]
kind?: AssistantMessageKind
modelName?: string
toolCalls?: ChatToolCall[]
timestamp?: string | number
}
@@ -41,6 +42,7 @@ export function AssistantMessage({
content,
attachments = [],
kind = "normal",
modelName,
toolCalls = [],
timestamp = "",
}: AssistantMessageProps) {
@@ -66,13 +68,20 @@ export function AssistantMessage({
const copyMessageLabel = isCopied
? t("chat.copiedLabel")
: t("chat.copyMessage")
const trimmedModelName = modelName?.trim() ?? ""
return (
<div className="group flex w-full flex-col gap-1.5">
{!isCollapsedBlock && (
<div className="text-muted-foreground/60 flex items-center justify-between gap-2 px-1 text-xs opacity-70">
<div className="text-muted-foreground/60 flex items-center justify-between gap-2 px-1 text-xs opacity-70">
<div className="flex items-center gap-2">
<span>PicoClaw</span>
{trimmedModelName && (
<>
<span className="opacity-50"></span>
<span>{trimmedModelName}</span>
</>
)}
{formattedTimestamp && (
<>
<span className="opacity-50"></span>
@@ -104,6 +113,9 @@ export function AssistantMessage({
<IconTool className="size-3.5" />
)}
<span>{collapsedLabel}</span>
{trimmedModelName && (
<span className="text-muted-foreground/45">{trimmedModelName}</span>
)}
</div>
<IconChevronDown
className={cn(
@@ -376,6 +376,7 @@ export function ChatPage() {
content={msg.content}
attachments={msg.attachments}
kind={msg.kind}
modelName={msg.modelName}
toolCalls={msg.toolCalls}
timestamp={msg.timestamp}
/>
+2 -1
View File
@@ -50,6 +50,7 @@ export async function loadSessionMessages(
role: message.role,
content: message.content,
kind: message.role === "assistant" ? (message.kind ?? "normal") : undefined,
modelName: message.model_name,
toolCalls:
message.role === "assistant"
? parseToolCallsValue(message.tool_calls)
@@ -86,7 +87,7 @@ function messageSignature(message: ChatMessage): string {
return `${message.role}\u0000${message.content}\u0000${normalizeMessageTimestamp(
message.timestamp,
)}\u0000${message.kind ?? ""}\u0000${attachmentSignature}\u0000${toolCallsSignature(
)}\u0000${message.kind ?? ""}\u0000${message.modelName ?? ""}\u0000${attachmentSignature}\u0000${toolCallsSignature(
message.toolCalls,
)}`
}
@@ -83,6 +83,14 @@ function parseContextUsage(
}
}
function parseModelName(payload: Record<string, unknown>): string | undefined {
if (typeof payload.model_name !== "string") {
return undefined
}
const modelName = payload.model_name.trim()
return modelName || undefined
}
export function handlePicoMessage(
message: PicoMessage,
expectedSessionId: string,
@@ -102,6 +110,7 @@ export function handlePicoMessage(
const attachments = parseAttachments(payload)
const contextUsage = parseContextUsage(payload)
const isPlaceholder = payload.placeholder === true
const modelName = parseModelName(payload)
const timestamp =
message.timestamp !== undefined &&
Number.isFinite(Number(message.timestamp))
@@ -116,6 +125,7 @@ export function handlePicoMessage(
role: "assistant",
content,
kind,
...(modelName ? { modelName } : {}),
...(toolCalls ? { toolCalls } : {}),
attachments,
timestamp,
@@ -135,6 +145,7 @@ export function handlePicoMessage(
const messageId = payload.message_id as string
const attachments = parseAttachments(payload)
const contextUsage = parseContextUsage(payload)
const modelName = parseModelName(payload)
const timestamp =
message.timestamp !== undefined &&
Number.isFinite(Number(message.timestamp))
@@ -160,6 +171,7 @@ export function handlePicoMessage(
content,
kind,
toolCalls,
...(modelName ? { modelName } : {}),
...(attachments ? { attachments } : {}),
}
})
@@ -178,6 +190,7 @@ export function handlePicoMessage(
content,
kind,
toolCalls,
...(modelName ? { modelName } : {}),
...(attachments ? { attachments } : {}),
timestamp,
},
+1
View File
@@ -44,6 +44,7 @@ export interface ChatMessage {
content: string
timestamp: number | string
kind?: AssistantMessageKind
modelName?: string
attachments?: ChatAttachment[]
toolCalls?: ChatToolCall[]
}