mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
feat(chat,seahorse): persist and display model_name across history (#2897)
* feat(chat,seahorse): persist and display model_name across history * test(seahorse): fix lint regressions in repair coverage * fix(pico): preserve model_name in live updates * fix(pico): preserve model_name through live stream wrappers
This commit is contained in:
@@ -600,6 +600,12 @@ func (al *AgentLoop) runAgentLoop(
|
||||
Content: result.finalContent,
|
||||
ContextUsage: computeContextUsage(agent, opts.Dispatch.SessionKey),
|
||||
}
|
||||
if modelName := strings.TrimSpace(result.modelName); modelName != "" {
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = make(map[string]string, 1)
|
||||
}
|
||||
msg.Context.Raw["model_name"] = modelName
|
||||
}
|
||||
markFinalOutbound(&msg)
|
||||
al.bus.PublishOutbound(ctx, msg)
|
||||
}
|
||||
|
||||
+32
-11
@@ -102,7 +102,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, chatID, sessionKey string) {
|
||||
func (al *AgentLoop) publishPicoReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, chatID, sessionKey, modelName string,
|
||||
) {
|
||||
if reasoningContent == "" || chatID == "" {
|
||||
return
|
||||
}
|
||||
@@ -114,13 +117,16 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer pubCancel()
|
||||
|
||||
raw := map[string]string{metadataKeyMessageKind: messageKindThought}
|
||||
if trimmedModelName := strings.TrimSpace(modelName); trimmedModelName != "" {
|
||||
raw["model_name"] = trimmedModelName
|
||||
}
|
||||
|
||||
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "pico",
|
||||
ChatID: chatID,
|
||||
Raw: map[string]string{
|
||||
metadataKeyMessageKind: messageKindThought,
|
||||
},
|
||||
Raw: raw,
|
||||
},
|
||||
SessionKey: sessionKey,
|
||||
Content: reasoningContent,
|
||||
@@ -143,6 +149,7 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
|
||||
func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
ctx context.Context,
|
||||
ts *turnState,
|
||||
modelName string,
|
||||
reasoningContent string,
|
||||
content string,
|
||||
toolCalls []providers.ToolCall,
|
||||
@@ -155,7 +162,14 @@ func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err := al.bus.PublishOutbound(
|
||||
pubCtx,
|
||||
outboundMessageForTurnWithKind(ts, reasoningContent, messageKindThought),
|
||||
outboundMessageForTurnWithOptions(
|
||||
ts,
|
||||
reasoningContent,
|
||||
outboundTurnMessageOptions{
|
||||
kind: messageKindThought,
|
||||
modelName: modelName,
|
||||
},
|
||||
),
|
||||
)
|
||||
pubCancel()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
|
||||
@@ -182,7 +196,12 @@ func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
|
||||
if strings.TrimSpace(content) != "" && !duplicateToolCallContent {
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err := al.bus.PublishOutbound(pubCtx, outboundMessageForTurn(ts, content))
|
||||
err := al.bus.PublishOutbound(
|
||||
pubCtx,
|
||||
outboundMessageForTurnWithOptions(ts, content, outboundTurnMessageOptions{
|
||||
modelName: modelName,
|
||||
}),
|
||||
)
|
||||
pubCancel()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
|
||||
!errors.Is(err, context.Canceled) &&
|
||||
@@ -209,11 +228,13 @@ func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
return
|
||||
}
|
||||
|
||||
msg := outboundMessageForTurnWithKind(ts, "", messageKindToolCalls)
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = map[string]string{}
|
||||
}
|
||||
msg.Context.Raw[metadataKeyToolCalls] = string(rawToolCalls)
|
||||
msg := outboundMessageForTurnWithOptions(ts, "", outboundTurnMessageOptions{
|
||||
kind: messageKindToolCalls,
|
||||
modelName: modelName,
|
||||
raw: map[string]string{
|
||||
metadataKeyToolCalls: string(rawToolCalls),
|
||||
},
|
||||
})
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err = al.bus.PublishOutbound(pubCtx, msg)
|
||||
|
||||
@@ -312,7 +312,7 @@ func TestPublishPicoReasoningIncludesSessionKey(t *testing.T) {
|
||||
defer cleanup()
|
||||
_ = provider
|
||||
|
||||
al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1")
|
||||
al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1", "")
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
|
||||
@@ -96,15 +96,46 @@ func markFinalOutbound(msg *bus.OutboundMessage) {
|
||||
msg.Context.Raw[metadataKeyOutboundKind] = outboundKindFinal
|
||||
}
|
||||
|
||||
func outboundMessageForTurnWithKind(ts *turnState, content, kind string) bus.OutboundMessage {
|
||||
type outboundTurnMessageOptions struct {
|
||||
kind string
|
||||
modelName string
|
||||
raw map[string]string
|
||||
}
|
||||
|
||||
func outboundMessageForTurnWithOptions(
|
||||
ts *turnState,
|
||||
content string,
|
||||
opts outboundTurnMessageOptions,
|
||||
) bus.OutboundMessage {
|
||||
msg := outboundMessageForTurn(ts, content)
|
||||
if strings.TrimSpace(kind) == "" {
|
||||
trimmedKind := strings.TrimSpace(opts.kind)
|
||||
trimmedModelName := strings.TrimSpace(opts.modelName)
|
||||
rawCount := len(opts.raw)
|
||||
if trimmedKind != "" {
|
||||
rawCount++
|
||||
}
|
||||
if trimmedModelName != "" {
|
||||
rawCount++
|
||||
}
|
||||
if rawCount == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = make(map[string]string, 1)
|
||||
msg.Context.Raw = make(map[string]string, rawCount)
|
||||
}
|
||||
if trimmedKind != "" {
|
||||
msg.Context.Raw[metadataKeyMessageKind] = trimmedKind
|
||||
}
|
||||
if trimmedModelName != "" {
|
||||
msg.Context.Raw["model_name"] = trimmedModelName
|
||||
}
|
||||
for key, value := range opts.raw {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
msg.Context.Raw[key] = value
|
||||
}
|
||||
msg.Context.Raw[metadataKeyMessageKind] = kind
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -521,8 +552,9 @@ func hasMediaRefs(messages []providers.Message) bool {
|
||||
|
||||
func sideQuestionModelName(agent *AgentInstance, usedLight bool) string {
|
||||
if usedLight && len(agent.LightCandidates) > 0 {
|
||||
// Use the first light candidate's model
|
||||
return agent.LightCandidates[0].Model
|
||||
if name := resolvedCandidateModelName(agent.LightCandidates, ""); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
return agent.Model
|
||||
}
|
||||
@@ -538,6 +570,14 @@ func modelNameFromIdentityKey(identityKey string) string {
|
||||
return identityKey
|
||||
}
|
||||
|
||||
func modelAliasFromCandidateIdentityKey(identityKey string) string {
|
||||
const prefix = "model_name:"
|
||||
if !strings.HasPrefix(identityKey, prefix) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(strings.TrimPrefix(identityKey, prefix))
|
||||
}
|
||||
|
||||
func closeProviderIfStateful(provider providers.LLMProvider) {
|
||||
if stateful, ok := provider.(providers.StatefulProvider); ok {
|
||||
stateful.Close()
|
||||
|
||||
@@ -197,6 +197,7 @@ func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
|
||||
result := seahorse.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ModelName: msg.ModelName,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
TokenCount: tokenizer.EstimateMessageTokens(msg),
|
||||
}
|
||||
@@ -243,6 +244,7 @@ func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes
|
||||
pm := protocoltypes.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ModelName: msg.ModelName,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
}
|
||||
|
||||
|
||||
@@ -174,6 +174,7 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
|
||||
msg := protocoltypes.Message{
|
||||
Role: "assistant",
|
||||
Content: "response text",
|
||||
ModelName: "gpt-5.4-mini",
|
||||
ReasoningContent: "I thought about this carefully",
|
||||
}
|
||||
|
||||
@@ -181,6 +182,9 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
|
||||
if result.ReasoningContent != "I thought about this carefully" {
|
||||
t.Errorf("ReasoningContent = %q, want 'I thought about this carefully'", result.ReasoningContent)
|
||||
}
|
||||
if result.ModelName != "gpt-5.4-mini" {
|
||||
t.Errorf("ModelName = %q, want %q", result.ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
|
||||
@@ -189,6 +193,7 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "response",
|
||||
ModelName: "gpt-5.4",
|
||||
ReasoningContent: "thinking process",
|
||||
},
|
||||
},
|
||||
@@ -201,6 +206,9 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
|
||||
if messages[0].ReasoningContent != "thinking process" {
|
||||
t.Errorf("ReasoningContent = %q, want 'thinking process'", messages[0].ReasoningContent)
|
||||
}
|
||||
if messages[0].ModelName != "gpt-5.4" {
|
||||
t.Errorf("ModelName = %q, want %q", messages[0].ModelName, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeahorseToProviderMessages(t *testing.T) {
|
||||
|
||||
@@ -75,6 +75,7 @@ func candidateFromModelConfig(
|
||||
return providers.FallbackCandidate{
|
||||
Provider: protocol,
|
||||
Model: modelID,
|
||||
DisplayName: strings.TrimSpace(mc.ModelName),
|
||||
RPM: mc.RPM,
|
||||
IdentityKey: modelConfigIdentityKey(mc),
|
||||
}, true
|
||||
@@ -147,8 +148,9 @@ func resolveModelCandidate(
|
||||
}
|
||||
|
||||
return providers.FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
DisplayName: raw,
|
||||
}, true
|
||||
}
|
||||
|
||||
@@ -197,6 +199,18 @@ func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallbac
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedCandidateModelName(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 {
|
||||
if name := modelAliasFromCandidateIdentityKey(candidates[0].IdentityKey); strings.TrimSpace(name) != "" {
|
||||
return name
|
||||
}
|
||||
if displayName := strings.TrimSpace(candidates[0].DisplayName); displayName != "" {
|
||||
return displayName
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(fallback)
|
||||
}
|
||||
|
||||
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
|
||||
@@ -7,6 +7,55 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestModelNameFromIdentityKey_LegacyProviderModel(t *testing.T) {
|
||||
if got := modelNameFromIdentityKey("openai/gpt-5.4"); got != "gpt-5.4" {
|
||||
t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelNameFromIdentityKey_PreservesNonLegacyIdentity(t *testing.T) {
|
||||
if got := modelNameFromIdentityKey("model_name:primary"); got != "model_name:primary" {
|
||||
t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "model_name:primary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelAliasFromCandidateIdentityKey(t *testing.T) {
|
||||
if got := modelAliasFromCandidateIdentityKey("model_name:primary"); got != "primary" {
|
||||
t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want %q", got, "primary")
|
||||
}
|
||||
if got := modelAliasFromCandidateIdentityKey("openai/gpt-5.4"); got != "" {
|
||||
t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvedCandidateModelName_PrefersIdentityAlias(t *testing.T) {
|
||||
got := resolvedCandidateModelName([]providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
|
||||
}, "fallback-model")
|
||||
if got != "primary" {
|
||||
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvedCandidateModelName_DoesNotScanFallbackAliases(t *testing.T) {
|
||||
got := resolvedCandidateModelName([]providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4"},
|
||||
{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:fallback"},
|
||||
}, "primary-model")
|
||||
if got != "primary-model" {
|
||||
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvedCandidateModelName_UsesCandidateDisplayName(t *testing.T) {
|
||||
got := resolvedCandidateModelName([]providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", DisplayName: "gpt-5.4-display"},
|
||||
}, "fallback-model")
|
||||
if got != "gpt-5.4-display" {
|
||||
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "gpt-5.4-display")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveActiveModelConfig_PrefersCandidateIdentityKey(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
|
||||
@@ -180,7 +180,11 @@ toolLoop:
|
||||
toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen),
|
||||
)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions(
|
||||
ts,
|
||||
feedbackMsg,
|
||||
outboundTurnMessageOptions{kind: messageKindToolFeedback},
|
||||
))
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
@@ -467,7 +471,11 @@ toolLoop:
|
||||
toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen),
|
||||
)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions(
|
||||
ts,
|
||||
feedbackMsg,
|
||||
outboundTurnMessageOptions{kind: messageKindToolFeedback},
|
||||
))
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ func (p *Pipeline) Finalize(
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
return turnResult{
|
||||
finalContent: finalContent,
|
||||
modelName: exec.llmModelName,
|
||||
status: turnStatus,
|
||||
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
|
||||
}, nil
|
||||
@@ -44,6 +45,7 @@ func (p *Pipeline) Finalize(
|
||||
finalMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: finalContent,
|
||||
ModelName: exec.llmModelName,
|
||||
ReasoningContent: responseReasoningContent(exec.response),
|
||||
}
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, finalMsg)
|
||||
@@ -80,24 +82,10 @@ func (p *Pipeline) Finalize(
|
||||
// so the final answer is still delivered outside normal SendResponse.
|
||||
if ((streamErr != nil && !isConfiguredStreamingVisibleError(streamErr)) || exec.streamingFallback) &&
|
||||
!ts.opts.SendResponse && ts.opts.AllowInterimPicoPublish && finalContent != "" {
|
||||
agentID, sessionKey, scope := outboundTurnMetadata(
|
||||
ts.agent.ID,
|
||||
ts.opts.Dispatch.SessionKey,
|
||||
ts.opts.Dispatch.SessionScope,
|
||||
)
|
||||
msg := bus.OutboundMessage{
|
||||
Context: outboundContextFromInbound(
|
||||
ts.opts.Dispatch.InboundContext,
|
||||
ts.opts.Dispatch.Channel(),
|
||||
ts.opts.Dispatch.ChatID(),
|
||||
ts.opts.Dispatch.ReplyToMessageID(),
|
||||
),
|
||||
AgentID: agentID,
|
||||
SessionKey: sessionKey,
|
||||
Scope: scope,
|
||||
Content: finalContent,
|
||||
ContextUsage: contextUsage,
|
||||
}
|
||||
msg := outboundMessageForTurnWithOptions(ts, finalContent, outboundTurnMessageOptions{
|
||||
modelName: exec.llmModelName,
|
||||
})
|
||||
msg.ContextUsage = contextUsage
|
||||
markFinalOutbound(&msg)
|
||||
_ = al.bus.PublishOutbound(turnCtx, msg)
|
||||
}
|
||||
@@ -112,6 +100,7 @@ func (p *Pipeline) Finalize(
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
return turnResult{
|
||||
finalContent: finalContent,
|
||||
modelName: exec.llmModelName,
|
||||
status: turnStatus,
|
||||
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
|
||||
}, nil
|
||||
|
||||
@@ -200,6 +200,16 @@ func (p *Pipeline) CallLLM(
|
||||
map[string]any{"agent_id": ts.agent.ID, "iteration": iteration},
|
||||
)
|
||||
}
|
||||
for _, candidate := range exec.activeCandidates {
|
||||
if candidate.StableKey() != fbResult.IdentityKey {
|
||||
continue
|
||||
}
|
||||
exec.llmModelName = resolvedCandidateModelName(
|
||||
[]providers.FallbackCandidate{candidate},
|
||||
exec.llmModelName,
|
||||
)
|
||||
break
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return exec.activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, exec.llmModel, exec.llmOpts)
|
||||
@@ -477,7 +487,7 @@ func (p *Pipeline) CallLLM(
|
||||
// Publish pico thoughts before the turn context is canceled at return time.
|
||||
// The async variant can race with turn teardown and intermittently drop the
|
||||
// thought message in CI even though the LLM produced reasoning content.
|
||||
al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey)
|
||||
al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey, exec.llmModelName)
|
||||
}
|
||||
} else {
|
||||
go al.handleReasoning(
|
||||
@@ -564,6 +574,7 @@ func (p *Pipeline) CallLLM(
|
||||
assistantMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: exec.response.Content,
|
||||
ModelName: exec.llmModelName,
|
||||
ReasoningContent: reasoningContent,
|
||||
}
|
||||
for _, tc := range exec.normalizedToolCalls {
|
||||
@@ -607,6 +618,7 @@ func (p *Pipeline) CallLLM(
|
||||
al.publishPicoToolCallInterim(
|
||||
turnCtx,
|
||||
ts,
|
||||
exec.llmModelName,
|
||||
reasoningContent,
|
||||
exec.response.Content,
|
||||
assistantMsg.ToolCalls,
|
||||
|
||||
@@ -89,6 +89,11 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution
|
||||
if usedLight && ts.agent.LightProvider != nil {
|
||||
activeProvider = ts.agent.LightProvider
|
||||
}
|
||||
activeModelName := strings.TrimSpace(ts.agent.Model)
|
||||
if usedLight {
|
||||
activeModelName = strings.TrimSpace(sideQuestionModelName(ts.agent, true))
|
||||
}
|
||||
activeModelName = resolvedCandidateModelName(activeCandidates, activeModelName)
|
||||
|
||||
exec := newTurnExecution(
|
||||
ts.agent,
|
||||
@@ -106,6 +111,7 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution
|
||||
activeModel,
|
||||
p.Cfg.Agents.Defaults.Provider,
|
||||
)
|
||||
exec.llmModelName = activeModelName
|
||||
exec.activeProvider = activeProvider
|
||||
exec.usedLight = usedLight
|
||||
|
||||
|
||||
@@ -50,9 +50,10 @@ func (p *Pipeline) tryConfiguredStreamingLLM(
|
||||
}
|
||||
|
||||
publisher := &streamingChunkPublisher{
|
||||
streamer: streamer,
|
||||
channel: ts.channel,
|
||||
chatID: ts.chatID,
|
||||
streamer: streamer,
|
||||
channel: ts.channel,
|
||||
chatID: ts.chatID,
|
||||
modelName: exec.llmModelName,
|
||||
}
|
||||
|
||||
logger.DebugCF("agent", "configured streaming enabled", map[string]any{
|
||||
@@ -371,6 +372,7 @@ type streamingChunkPublisher struct {
|
||||
streamer bus.Streamer
|
||||
channel string
|
||||
chatID string
|
||||
modelName string
|
||||
published bool
|
||||
reasoningPublished bool
|
||||
err error
|
||||
@@ -380,6 +382,9 @@ func (p *streamingChunkPublisher) Update(ctx context.Context, accumulated string
|
||||
if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" {
|
||||
return
|
||||
}
|
||||
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName(p.modelName)
|
||||
}
|
||||
if err := p.streamer.Update(ctx, accumulated); err != nil {
|
||||
p.err = err
|
||||
logger.WarnCF("agent", "stream update failed", map[string]any{
|
||||
@@ -396,6 +401,9 @@ func (p *streamingChunkPublisher) UpdateReasoning(ctx context.Context, accumulat
|
||||
if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" {
|
||||
return
|
||||
}
|
||||
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName(p.modelName)
|
||||
}
|
||||
reasoningStreamer, ok := p.streamer.(bus.ReasoningStreamer)
|
||||
if !ok {
|
||||
return
|
||||
@@ -434,6 +442,9 @@ func (p *streamingChunkPublisher) Finalize(ctx context.Context, content string,
|
||||
if strings.TrimSpace(content) == "" && !p.published {
|
||||
return nil
|
||||
}
|
||||
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName(p.modelName)
|
||||
}
|
||||
var err error
|
||||
if streamer, ok := p.streamer.(bus.ContextUsageStreamer); ok {
|
||||
err = streamer.FinalizeWithContext(ctx, content, contextUsage)
|
||||
|
||||
@@ -570,6 +570,9 @@ func TestConfiguredStreamingFinalFlushFailureBeforeVisibleOutputPublishesFallbac
|
||||
if outbound.Content != "stream response" {
|
||||
t.Fatalf("fallback outbound content = %q, want stream response", outbound.Content)
|
||||
}
|
||||
if got := outbound.Context.Raw["model_name"]; got != "test-model" {
|
||||
t.Fatalf("fallback outbound model_name = %q, want %q", got, "test-model")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected fallback outbound after invisible final stream flush failure")
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
@@ -37,6 +38,39 @@ func (p *simpleConvProvider) GetDefaultModel() string {
|
||||
return "simple-model"
|
||||
}
|
||||
|
||||
type sequenceProvider struct {
|
||||
responses []*providers.LLMResponse
|
||||
errors []error
|
||||
callCount int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (p *sequenceProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
idx := p.callCount
|
||||
p.callCount++
|
||||
|
||||
if idx < len(p.errors) && p.errors[idx] != nil {
|
||||
return nil, p.errors[idx]
|
||||
}
|
||||
if idx < len(p.responses) && p.responses[idx] != nil {
|
||||
return p.responses[idx], nil
|
||||
}
|
||||
return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
func (p *sequenceProvider) GetDefaultModel() string {
|
||||
return "sequence-model"
|
||||
}
|
||||
|
||||
type nativeSearchCaptureProvider struct {
|
||||
lastOpts map[string]any
|
||||
}
|
||||
@@ -271,6 +305,152 @@ func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_SetupTurn_ModelNameDoesNotUseFallbackAliasBeforeFallback(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4"},
|
||||
{Provider: "anthropic", Model: "claude-sonnet", IdentityKey: "model_name:fallback-model"},
|
||||
}
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
if exec.llmModelName != "primary-model" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "primary-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_UsesSuccessfulFallbackIdentityAlias(t *testing.T) {
|
||||
provider := &sequenceProvider{
|
||||
errors: []error{
|
||||
errors.New("status: 429 - rate limit exceeded"),
|
||||
nil,
|
||||
},
|
||||
responses: []*providers.LLMResponse{
|
||||
nil,
|
||||
{Content: "fallback answer", FinishReason: "stop"},
|
||||
},
|
||||
}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:secondary"},
|
||||
}
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallLLM failed: %v", err)
|
||||
}
|
||||
if ctrl != ControlBreak {
|
||||
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
||||
}
|
||||
if exec.llmModelName != "secondary" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "secondary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_UsesSuccessfulFallbackDisplayNameWithoutAlias(t *testing.T) {
|
||||
provider := &sequenceProvider{
|
||||
errors: []error{
|
||||
errors.New("status: 429 - rate limit exceeded"),
|
||||
nil,
|
||||
},
|
||||
responses: []*providers.LLMResponse{
|
||||
nil,
|
||||
{Content: "fallback answer", FinishReason: "stop"},
|
||||
},
|
||||
}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
|
||||
{Provider: "anthropic", Model: "claude-sonnet", DisplayName: "anthropic/claude-sonnet"},
|
||||
}
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallLLM failed: %v", err)
|
||||
}
|
||||
if ctrl != ControlBreak {
|
||||
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
||||
}
|
||||
if exec.llmModelName != "anthropic/claude-sonnet" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "anthropic/claude-sonnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_SetupTurn_UsesLightCandidateDisplayName(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
|
||||
}
|
||||
agent.LightCandidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:light-model", DisplayName: "light-model"},
|
||||
}
|
||||
agent.Router = routing.New(routing.RouterConfig{LightModel: "light-model", Threshold: 1})
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
opts := makeTestProcessOpts("test-session")
|
||||
opts.UserMessage = ""
|
||||
ts := newTurnState(agent, opts, turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
if !exec.usedLight {
|
||||
t.Fatal("expected light routing to be used")
|
||||
}
|
||||
if exec.llmModelName != "light-model" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "light-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunTurn_FinalizeSaveErrorEmitsErrorTurnEnd(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
@@ -84,6 +84,7 @@ const (
|
||||
|
||||
type turnResult struct {
|
||||
finalContent string
|
||||
modelName string
|
||||
status TurnEndStatus
|
||||
followUps []bus.InboundMessage
|
||||
}
|
||||
@@ -140,6 +141,7 @@ type turnExecution struct {
|
||||
callMessages []providers.Message
|
||||
providerToolDefs []providers.ToolDefinition
|
||||
llmModel string
|
||||
llmModelName string
|
||||
llmOpts map[string]any
|
||||
gracefulTerminal bool
|
||||
useNativeSearch bool
|
||||
|
||||
@@ -20,6 +20,17 @@ type MessageEditor interface {
|
||||
EditMessage(ctx context.Context, chatID string, messageID string, content string) error
|
||||
}
|
||||
|
||||
// MessageEditorWithPayload extends MessageEditor for channels that can update
|
||||
// structured message metadata in addition to plain text content.
|
||||
type MessageEditorWithPayload interface {
|
||||
EditMessageWithPayload(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
messageID string,
|
||||
payload map[string]any,
|
||||
) error
|
||||
}
|
||||
|
||||
// MessageDeleter — channels that can delete a message by ID.
|
||||
type MessageDeleter interface {
|
||||
DeleteMessage(ctx context.Context, chatID string, messageID string) error
|
||||
|
||||
+63
-2
@@ -191,6 +191,19 @@ func outboundMessageBypassesPlaceholderEdit(msg bus.OutboundMessage) bool {
|
||||
return strings.EqualFold(kind, "thought") || strings.EqualFold(kind, "tool_calls")
|
||||
}
|
||||
|
||||
func outboundMessageEditPayload(msg bus.OutboundMessage, content string) map[string]any {
|
||||
payload := map[string]any{
|
||||
"content": content,
|
||||
}
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return payload
|
||||
}
|
||||
if modelName := strings.TrimSpace(msg.Context.Raw["model_name"]); modelName != "" {
|
||||
payload["model_name"] = modelName
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
@@ -394,7 +407,16 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
|
||||
} else if editor, ok := ch.(MessageEditor); ok {
|
||||
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
|
||||
if payloadEditor, ok := ch.(MessageEditorWithPayload); ok {
|
||||
_ = payloadEditor.EditMessageWithPayload(
|
||||
ctx,
|
||||
chatID,
|
||||
entry.id,
|
||||
outboundMessageEditPayload(msg, msg.Content),
|
||||
)
|
||||
} else {
|
||||
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -446,7 +468,18 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
trackedContent = prepareToolFeedbackMessageContent(ch, msg.Content)
|
||||
content = InitialAnimatedToolFeedbackContent(trackedContent)
|
||||
}
|
||||
if err := editor.EditMessage(ctx, chatID, entry.id, content); err == nil {
|
||||
err := func() error {
|
||||
if payloadEditor, ok := ch.(MessageEditorWithPayload); ok {
|
||||
return payloadEditor.EditMessageWithPayload(
|
||||
ctx,
|
||||
chatID,
|
||||
entry.id,
|
||||
outboundMessageEditPayload(msg, content),
|
||||
)
|
||||
}
|
||||
return editor.EditMessage(ctx, chatID, entry.id, content)
|
||||
}()
|
||||
if err == nil {
|
||||
trackedChatID := trackedToolFeedbackMessageChatID(ch, chatID, &msg.Context)
|
||||
if tracker, ok := ch.(toolFeedbackMessageTracker); ok && isToolFeedback {
|
||||
tracker.RecordToolFeedbackMessage(trackedChatID, entry.id, trackedContent)
|
||||
@@ -643,6 +676,18 @@ func reasoningStreamerFrom(streamer bus.Streamer) bus.ReasoningStreamer {
|
||||
return nil
|
||||
}
|
||||
|
||||
type modelNameStreamer interface {
|
||||
SetModelName(modelName string)
|
||||
}
|
||||
|
||||
func setStreamerModelName(streamer any, modelName string) {
|
||||
setter, ok := streamer.(modelNameStreamer)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
setter.SetModelName(modelName)
|
||||
}
|
||||
|
||||
// splitMarkerStreamer turns accumulated streaming text containing
|
||||
// MessageSplitMarker into separate channel stream messages.
|
||||
type splitMarkerStreamer struct {
|
||||
@@ -654,6 +699,7 @@ type splitMarkerStreamer struct {
|
||||
finalized bool
|
||||
onFinalize func(context.Context, string)
|
||||
clearMarker func()
|
||||
modelName string
|
||||
}
|
||||
|
||||
func (s *splitMarkerStreamer) Update(ctx context.Context, content string) error {
|
||||
@@ -682,6 +728,7 @@ func (s *splitMarkerStreamer) UpdateReasoning(ctx context.Context, content strin
|
||||
if s.reasoning == nil {
|
||||
return nil
|
||||
}
|
||||
setStreamerModelName(s.reasoning, s.modelName)
|
||||
return s.reasoning.UpdateReasoning(ctx, content)
|
||||
}
|
||||
|
||||
@@ -691,9 +738,18 @@ func (s *splitMarkerStreamer) FinalizeReasoning(ctx context.Context, content str
|
||||
if s.reasoning == nil {
|
||||
return nil
|
||||
}
|
||||
setStreamerModelName(s.reasoning, s.modelName)
|
||||
return s.reasoning.FinalizeReasoning(ctx, content)
|
||||
}
|
||||
|
||||
func (s *splitMarkerStreamer) SetModelName(modelName string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.modelName = strings.TrimSpace(modelName)
|
||||
setStreamerModelName(s.current, s.modelName)
|
||||
setStreamerModelName(s.reasoning, s.modelName)
|
||||
}
|
||||
|
||||
func (s *splitMarkerStreamer) Cancel(ctx context.Context) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -772,6 +828,7 @@ func (s *splitMarkerStreamer) ensureCurrentLocked(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
s.current = streamer
|
||||
setStreamerModelName(s.current, s.modelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -856,6 +913,10 @@ func (s *finalizeHookStreamer) FinalizeReasoning(ctx context.Context, content st
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *finalizeHookStreamer) SetModelName(modelName string) {
|
||||
setStreamerModelName(s.Streamer, strings.TrimSpace(modelName))
|
||||
}
|
||||
|
||||
func (s *finalizeHookStreamer) runFinalizeHook(ctx context.Context, content string) {
|
||||
if s.onFinalize != nil {
|
||||
s.onFinalize(ctx, content)
|
||||
|
||||
@@ -142,11 +142,21 @@ func (m *mockReasoningStreamer) FinalizeReasoning(_ context.Context, content str
|
||||
return nil
|
||||
}
|
||||
|
||||
type modelTrackingReasoningStreamer struct {
|
||||
mockReasoningStreamer
|
||||
modelNames []string
|
||||
}
|
||||
|
||||
func (m *modelTrackingReasoningStreamer) SetModelName(modelName string) {
|
||||
m.modelNames = append(m.modelNames, strings.TrimSpace(modelName))
|
||||
}
|
||||
|
||||
type recordingStreamSegment struct {
|
||||
updates []string
|
||||
finals []string
|
||||
finalUsage *bus.ContextUsage
|
||||
canceledCount int
|
||||
modelNames []string
|
||||
}
|
||||
|
||||
func (s *recordingStreamSegment) Update(_ context.Context, content string) error {
|
||||
@@ -168,6 +178,10 @@ func (s *recordingStreamSegment) Cancel(context.Context) {
|
||||
s.canceledCount++
|
||||
}
|
||||
|
||||
func (s *recordingStreamSegment) SetModelName(modelName string) {
|
||||
s.modelNames = append(s.modelNames, strings.TrimSpace(modelName))
|
||||
}
|
||||
|
||||
type mockStreamingChannel struct {
|
||||
mockMessageEditor
|
||||
streamer Streamer
|
||||
@@ -2068,6 +2082,42 @@ func TestGetStreamer_PreservesReasoningStreamer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamer_PreservesModelNameSetter(t *testing.T) {
|
||||
m := newTestManager()
|
||||
inner := &modelTrackingReasoningStreamer{}
|
||||
ch := &mockStreamingChannel{
|
||||
streamer: inner,
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
|
||||
streamer, ok := m.GetStreamer(context.Background(), "test", "123", "")
|
||||
if !ok {
|
||||
t.Fatal("expected streamer to be available")
|
||||
}
|
||||
setter, ok := streamer.(interface{ SetModelName(modelName string) })
|
||||
if !ok {
|
||||
t.Fatal("manager-wrapped streamer should preserve SetModelName")
|
||||
}
|
||||
setter.SetModelName("gpt-5.4")
|
||||
if err := streamer.Update(context.Background(), "hello"); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
reasoningStreamer, ok := streamer.(bus.ReasoningStreamer)
|
||||
if !ok {
|
||||
t.Fatal("manager-wrapped streamer should preserve ReasoningStreamer")
|
||||
}
|
||||
setter.SetModelName("gpt-5.4")
|
||||
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil {
|
||||
t.Fatalf("UpdateReasoning() error = %v", err)
|
||||
}
|
||||
if len(inner.modelNames) != 2 {
|
||||
t.Fatalf("model name calls = %v, want 2 forwarded calls", inner.modelNames)
|
||||
}
|
||||
if inner.modelNames[0] != "gpt-5.4" || inner.modelNames[1] != "gpt-5.4" {
|
||||
t.Fatalf("model name calls = %v, want both forwarded as gpt-5.4", inner.modelNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamer_SplitOnMarkerStreamsSeparateSegments(t *testing.T) {
|
||||
m := newTestManager()
|
||||
m.config = &config.Config{
|
||||
@@ -2188,6 +2238,58 @@ func TestGetStreamer_SplitOnMarkerKeepsReasoningOnInitialStreamer(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamer_SplitOnMarkerPreservesModelNameSetter(t *testing.T) {
|
||||
m := newTestManager()
|
||||
m.config = &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
SplitOnMarker: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
initial := &modelTrackingReasoningStreamer{}
|
||||
next := &recordingStreamSegment{}
|
||||
callCount := 0
|
||||
ch := &mockStreamingChannel{
|
||||
beginStreamFn: func(context.Context, string) (Streamer, error) {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
return initial, nil
|
||||
}
|
||||
return next, nil
|
||||
},
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
|
||||
streamer, ok := m.GetStreamer(context.Background(), "test", "123", "")
|
||||
if !ok {
|
||||
t.Fatal("expected streamer to be available")
|
||||
}
|
||||
setter, ok := streamer.(interface{ SetModelName(modelName string) })
|
||||
if !ok {
|
||||
t.Fatal("split streamer should preserve SetModelName")
|
||||
}
|
||||
setter.SetModelName("gpt-5.4-mini")
|
||||
if err := streamer.Update(context.Background(), "hello<|[SPLIT]|>world"); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
reasoningStreamer, ok := streamer.(bus.ReasoningStreamer)
|
||||
if !ok {
|
||||
t.Fatal("split streamer should preserve ReasoningStreamer")
|
||||
}
|
||||
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil {
|
||||
t.Fatalf("UpdateReasoning() error = %v", err)
|
||||
}
|
||||
|
||||
if len(initial.modelNames) == 0 || initial.modelNames[0] != "gpt-5.4-mini" {
|
||||
t.Fatalf("initial model names = %v, want forwarded gpt-5.4-mini", initial.modelNames)
|
||||
}
|
||||
if len(next.modelNames) == 0 || next.modelNames[0] != "gpt-5.4-mini" {
|
||||
t.Fatalf("next model names = %v, want forwarded gpt-5.4-mini", next.modelNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamer_FinalizeSeparateMessagesClearsTrackedToolFeedback(t *testing.T) {
|
||||
m := newTestManager()
|
||||
m.config = &config.Config{
|
||||
|
||||
@@ -325,6 +325,9 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
PayloadKeyContent: content,
|
||||
"message_id": msgID,
|
||||
}
|
||||
if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" {
|
||||
payload[PayloadKeyModelName] = modelName
|
||||
}
|
||||
switch {
|
||||
case isThought:
|
||||
payload[PayloadKeyKind] = MessageKindThought
|
||||
@@ -359,6 +362,15 @@ func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID
|
||||
return c.editMessage(ctx, chatID, messageID, content, nil)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) EditMessageWithPayload(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
messageID string,
|
||||
payload map[string]any,
|
||||
) error {
|
||||
return c.editMessagePayload(ctx, chatID, messageID, payload, nil)
|
||||
}
|
||||
|
||||
// DeleteMessage implements channels.MessageDeleter.
|
||||
func (c *PicoChannel) DeleteMessage(ctx context.Context, chatID string, messageID string) error {
|
||||
outMsg := newMessage(TypeMessageDelete, map[string]any{
|
||||
@@ -419,14 +431,23 @@ func (c *PicoChannel) finalizeTrackedToolFeedbackMessage(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
content string,
|
||||
editFn func(context.Context, string, string, string, *bus.ContextUsage) error,
|
||||
editFn func(context.Context, string, string, map[string]any, *bus.ContextUsage) error,
|
||||
payload map[string]any,
|
||||
contextUsage *bus.ContextUsage,
|
||||
) ([]string, bool) {
|
||||
msgID, baseContent, ok := c.takeToolFeedbackMessage(chatID)
|
||||
if !ok || editFn == nil {
|
||||
return nil, false
|
||||
}
|
||||
if err := editFn(ctx, chatID, msgID, content, contextUsage); err != nil {
|
||||
if payload == nil {
|
||||
payload = map[string]any{
|
||||
PayloadKeyContent: content,
|
||||
}
|
||||
}
|
||||
if _, ok := payload[PayloadKeyContent]; !ok {
|
||||
payload[PayloadKeyContent] = content
|
||||
}
|
||||
if err := editFn(ctx, chatID, msgID, payload, contextUsage); err != nil {
|
||||
c.RecordToolFeedbackMessage(chatID, msgID, baseContent)
|
||||
return nil, false
|
||||
}
|
||||
@@ -437,7 +458,20 @@ func (c *PicoChannel) FinalizeToolFeedbackMessage(ctx context.Context, msg bus.O
|
||||
if !outboundMessageFinalizesTrackedToolFeedback(msg) {
|
||||
return nil, false
|
||||
}
|
||||
return c.finalizeTrackedToolFeedbackMessage(ctx, msg.ChatID, msg.Content, c.editMessage, msg.ContextUsage)
|
||||
payload := map[string]any{
|
||||
PayloadKeyContent: msg.Content,
|
||||
}
|
||||
if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" {
|
||||
payload[PayloadKeyModelName] = modelName
|
||||
}
|
||||
return c.finalizeTrackedToolFeedbackMessage(
|
||||
ctx,
|
||||
msg.ChatID,
|
||||
msg.Content,
|
||||
c.editMessagePayload,
|
||||
payload,
|
||||
msg.ContextUsage,
|
||||
)
|
||||
}
|
||||
|
||||
// StartTyping implements channels.TypingCapable.
|
||||
@@ -496,6 +530,7 @@ func (c *PicoChannel) BeginStream(ctx context.Context, chatID string) (channels.
|
||||
type picoStreamer struct {
|
||||
channel *PicoChannel
|
||||
chatID string
|
||||
modelName string
|
||||
messageID string
|
||||
reasoningID string
|
||||
throttleInterval time.Duration
|
||||
@@ -509,6 +544,15 @@ type picoStreamer struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *picoStreamer) SetModelName(modelName string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.modelName = strings.TrimSpace(modelName)
|
||||
}
|
||||
|
||||
func (s *picoStreamer) Update(ctx context.Context, content string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -613,13 +657,23 @@ func (s *picoStreamer) sendLocked(ctx context.Context, content string, contextUs
|
||||
PayloadKeyContent: content,
|
||||
"message_id": s.messageID,
|
||||
}
|
||||
if s.modelName != "" {
|
||||
payload[PayloadKeyModelName] = s.modelName
|
||||
}
|
||||
setContextUsagePayload(payload, contextUsage)
|
||||
outMsg := newMessage(TypeMessageCreate, payload)
|
||||
if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if content != s.lastContent || contextUsage != nil {
|
||||
if err := s.channel.editMessage(ctx, s.chatID, s.messageID, content, contextUsage); err != nil {
|
||||
payload := map[string]any{
|
||||
PayloadKeyContent: content,
|
||||
"message_id": s.messageID,
|
||||
}
|
||||
if s.modelName != "" {
|
||||
payload[PayloadKeyModelName] = s.modelName
|
||||
}
|
||||
if err := s.channel.editMessagePayload(ctx, s.chatID, s.messageID, payload, contextUsage); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -642,6 +696,9 @@ func (s *picoStreamer) sendReasoningLocked(ctx context.Context, content string)
|
||||
PayloadKeyKind: MessageKindThought,
|
||||
PayloadKeyThought: true,
|
||||
}
|
||||
if s.modelName != "" {
|
||||
payload[PayloadKeyModelName] = s.modelName
|
||||
}
|
||||
outMsg := newMessage(TypeMessageCreate, payload)
|
||||
if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil {
|
||||
return err
|
||||
@@ -653,6 +710,9 @@ func (s *picoStreamer) sendReasoningLocked(ctx context.Context, content string)
|
||||
PayloadKeyKind: MessageKindThought,
|
||||
PayloadKeyThought: true,
|
||||
}
|
||||
if s.modelName != "" {
|
||||
payload[PayloadKeyModelName] = s.modelName
|
||||
}
|
||||
outMsg := newMessage(TypeMessageUpdate, payload)
|
||||
if err := s.channel.broadcastToSession(s.chatID, outMsg); err != nil {
|
||||
return err
|
||||
@@ -744,6 +804,9 @@ func (c *PicoChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessag
|
||||
"attachments": attachments,
|
||||
"message_id": msgID,
|
||||
})
|
||||
if modelName := strings.TrimSpace(msg.Context.Raw[PayloadKeyModelName]); modelName != "" {
|
||||
outMsg.Payload[PayloadKeyModelName] = modelName
|
||||
}
|
||||
|
||||
if err := c.broadcastToSession(msg.ChatID, outMsg); err != nil {
|
||||
return nil, err
|
||||
@@ -1358,11 +1421,30 @@ func (c *PicoChannel) editMessage(
|
||||
content string,
|
||||
contextUsage *bus.ContextUsage,
|
||||
) error {
|
||||
payload := map[string]any{
|
||||
"message_id": messageID,
|
||||
"content": content,
|
||||
return c.editMessagePayload(ctx, chatID, messageID, map[string]any{
|
||||
PayloadKeyContent: content,
|
||||
}, contextUsage)
|
||||
}
|
||||
|
||||
func (c *PicoChannel) editMessagePayload(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
messageID string,
|
||||
payload map[string]any,
|
||||
contextUsage *bus.ContextUsage,
|
||||
) error {
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
setContextUsagePayload(payload, contextUsage)
|
||||
outMsg := newMessage(TypeMessageUpdate, payload)
|
||||
normalized := make(map[string]any, len(payload)+1)
|
||||
for key, value := range payload {
|
||||
normalized[key] = value
|
||||
}
|
||||
if _, ok := normalized[PayloadKeyContent]; !ok {
|
||||
normalized[PayloadKeyContent] = ""
|
||||
}
|
||||
normalized["message_id"] = messageID
|
||||
setContextUsagePayload(normalized, contextUsage)
|
||||
outMsg := newMessage(TypeMessageUpdate, normalized)
|
||||
return c.broadcastToSession(chatID, outMsg)
|
||||
}
|
||||
|
||||
@@ -46,12 +46,15 @@ func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T
|
||||
context.Background(),
|
||||
"pico:chat-1",
|
||||
"final reply",
|
||||
func(_ context.Context, chatID, messageID, content string, contextUsage *bus.ContextUsage) error {
|
||||
func(_ context.Context, chatID, messageID string, payload map[string]any, contextUsage *bus.ContextUsage) error {
|
||||
if _, ok := ch.currentToolFeedbackMessage(chatID); ok {
|
||||
t.Fatal("expected tracked tool feedback to be stopped before edit")
|
||||
}
|
||||
if chatID != "pico:chat-1" || messageID != "msg-1" || content != "final reply" {
|
||||
t.Fatalf("unexpected edit args: %s %s %s", chatID, messageID, content)
|
||||
if chatID != "pico:chat-1" || messageID != "msg-1" {
|
||||
t.Fatalf("unexpected edit args: %s %s", chatID, messageID)
|
||||
}
|
||||
if got := payload[PayloadKeyContent]; got != "final reply" {
|
||||
t.Fatalf("unexpected content payload: %#v", got)
|
||||
}
|
||||
if contextUsage != nil {
|
||||
t.Fatalf("unexpected context usage: %+v", contextUsage)
|
||||
@@ -59,6 +62,7 @@ func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T
|
||||
return nil
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if !handled {
|
||||
t.Fatal("expected finalizeTrackedToolFeedbackMessage to handle tracked message")
|
||||
@@ -115,7 +119,8 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
|
||||
Channel: "pico",
|
||||
ChatID: "pico:sess-1",
|
||||
Raw: map[string]string{
|
||||
"message_kind": MessageKindThought,
|
||||
"message_kind": MessageKindThought,
|
||||
PayloadKeyModelName: "gpt-5.4-mini",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
@@ -134,6 +139,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
|
||||
if got := payload[PayloadKeyKind]; got != MessageKindThought {
|
||||
t.Fatalf("thought kind = %#v, want %q", got, MessageKindThought)
|
||||
}
|
||||
if got := payload[PayloadKeyModelName]; got != "gpt-5.4-mini" {
|
||||
t.Fatalf("thought model_name = %#v, want %q", got, "gpt-5.4-mini")
|
||||
}
|
||||
if got := payload["message_id"]; got == "msg-progress" || got == nil || got == "" {
|
||||
t.Fatalf("thought message_id = %#v, want new non-progress id", got)
|
||||
}
|
||||
@@ -151,6 +159,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
|
||||
Context: bus.InboundContext{
|
||||
Channel: "pico",
|
||||
ChatID: "pico:sess-1",
|
||||
Raw: map[string]string{
|
||||
PayloadKeyModelName: "gpt-5.4",
|
||||
},
|
||||
},
|
||||
ContextUsage: &bus.ContextUsage{
|
||||
UsedTokens: 321,
|
||||
@@ -174,6 +185,9 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
|
||||
if got := payload[PayloadKeyContent]; got != "final reply" {
|
||||
t.Fatalf("final content = %#v, want %q", got, "final reply")
|
||||
}
|
||||
if got := payload[PayloadKeyModelName]; got != "gpt-5.4" {
|
||||
t.Fatalf("final model_name = %#v, want %q", got, "gpt-5.4")
|
||||
}
|
||||
rawUsage, ok := payload["context_usage"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("final context_usage = %#v, want map payload", payload["context_usage"])
|
||||
@@ -193,6 +207,54 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_ToolCallsMessageIncludesModelName(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
|
||||
if err := ch.Start(context.Background()); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
defer ch.Stop(context.Background())
|
||||
|
||||
clientConn, received, cleanup := newTestPicoWebSocket(t)
|
||||
defer cleanup()
|
||||
ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"})
|
||||
|
||||
if _, err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "pico:sess-1",
|
||||
Content: "",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "pico",
|
||||
ChatID: "pico:sess-1",
|
||||
Raw: map[string]string{
|
||||
"message_kind": MessageKindToolCalls,
|
||||
PayloadKeyModelName: "gpt-5.4",
|
||||
PayloadKeyToolCalls: `[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"README.md\"}"}}]`,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("Send(tool_calls) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-received:
|
||||
if msg.Type != TypeMessageCreate {
|
||||
t.Fatalf("tool_calls message type = %q, want %q", msg.Type, TypeMessageCreate)
|
||||
}
|
||||
payload := msg.Payload
|
||||
if got := payload[PayloadKeyKind]; got != MessageKindToolCalls {
|
||||
t.Fatalf("tool_calls kind = %#v, want %q", got, MessageKindToolCalls)
|
||||
}
|
||||
if got := payload[PayloadKeyModelName]; got != "gpt-5.4" {
|
||||
t.Fatalf("tool_calls model_name = %#v, want %q", got, "gpt-5.4")
|
||||
}
|
||||
if _, ok := payload[PayloadKeyToolCalls].([]any); !ok {
|
||||
t.Fatalf("tool_calls payload = %#v, want parsed array", payload[PayloadKeyToolCalls])
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected tool_calls message to be delivered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendPlaceholder_EmitsNormalMessageWithoutKind(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
ch.bc.Placeholder.Enabled = true
|
||||
@@ -257,6 +319,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("BeginStream() error = %v", err)
|
||||
}
|
||||
if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName("gpt-5.4")
|
||||
}
|
||||
if err := streamer.Update(context.Background(), "hello"); err != nil {
|
||||
t.Fatalf("Update(first) error = %v", err)
|
||||
}
|
||||
@@ -271,6 +336,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) {
|
||||
if got := first.Payload[PayloadKeyContent]; got != "hello" {
|
||||
t.Fatalf("first content = %#v, want hello", got)
|
||||
}
|
||||
if got := first.Payload[PayloadKeyModelName]; got != "gpt-5.4" {
|
||||
t.Fatalf("first model_name = %#v, want %q", got, "gpt-5.4")
|
||||
}
|
||||
|
||||
rawStreamer := streamer.(*picoStreamer)
|
||||
rawStreamer.mu.Lock()
|
||||
@@ -290,6 +358,9 @@ func TestBeginStream_CreatesAndUpdatesSameMessage(t *testing.T) {
|
||||
if got := second.Payload[PayloadKeyContent]; got != secondContent {
|
||||
t.Fatalf("second content = %#v, want %q", got, secondContent)
|
||||
}
|
||||
if got := second.Payload[PayloadKeyModelName]; got != "gpt-5.4" {
|
||||
t.Fatalf("second model_name = %#v, want %q", got, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeginStream_DefaultStreamingShowsSmallIncrements(t *testing.T) {
|
||||
@@ -355,6 +426,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("pico stream should support reasoning updates")
|
||||
}
|
||||
if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName("gpt-5.4-mini")
|
||||
}
|
||||
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking"); err != nil {
|
||||
t.Fatalf("UpdateReasoning(first) error = %v", err)
|
||||
}
|
||||
@@ -372,6 +446,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) {
|
||||
if got := first.Payload[PayloadKeyContent]; got != "thinking" {
|
||||
t.Fatalf("first content = %#v, want thinking", got)
|
||||
}
|
||||
if got := first.Payload[PayloadKeyModelName]; got != "gpt-5.4-mini" {
|
||||
t.Fatalf("first model_name = %#v, want %q", got, "gpt-5.4-mini")
|
||||
}
|
||||
|
||||
if err := reasoningStreamer.UpdateReasoning(context.Background(), "thinking more"); err != nil {
|
||||
t.Fatalf("UpdateReasoning(second) error = %v", err)
|
||||
@@ -389,6 +466,9 @@ func TestBeginStream_StreamsReasoningAsThoughtUpdates(t *testing.T) {
|
||||
if got := second.Payload[PayloadKeyContent]; got != "thinking more" {
|
||||
t.Fatalf("second content = %#v, want thinking more", got)
|
||||
}
|
||||
if got := second.Payload[PayloadKeyModelName]; got != "gpt-5.4-mini" {
|
||||
t.Fatalf("second model_name = %#v, want %q", got, "gpt-5.4-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeginStream_ThrottlesIntermediateUpdatesAndFinalFlushes(t *testing.T) {
|
||||
@@ -473,6 +553,9 @@ func TestBeginStream_FinalizeIncludesContextUsage(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("BeginStream() error = %v", err)
|
||||
}
|
||||
if setter, ok := streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName("gpt-5.4")
|
||||
}
|
||||
if err := streamer.Update(context.Background(), "partial"); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
@@ -501,6 +584,9 @@ func TestBeginStream_FinalizeIncludesContextUsage(t *testing.T) {
|
||||
if got := final.Payload["message_id"]; got != msgID {
|
||||
t.Fatalf("final message_id = %#v, want %q", got, msgID)
|
||||
}
|
||||
if got := final.Payload[PayloadKeyModelName]; got != "gpt-5.4" {
|
||||
t.Fatalf("final model_name = %#v, want %q", got, "gpt-5.4")
|
||||
}
|
||||
rawUsage, ok := final.Payload["context_usage"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("final context_usage = %#v, want map", final.Payload["context_usage"])
|
||||
|
||||
@@ -27,6 +27,7 @@ const (
|
||||
PayloadKeyKind = "kind"
|
||||
PayloadKeyPlaceholder = "placeholder"
|
||||
PayloadKeyToolCalls = "tool_calls"
|
||||
PayloadKeyModelName = "model_name"
|
||||
|
||||
MessageKindThought = "thought"
|
||||
MessageKindToolCalls = "tool_calls"
|
||||
|
||||
@@ -130,6 +130,32 @@ func TestAddFullMessage_WithToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFullMessage_PreservesModelName(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "done",
|
||||
ModelName: "gpt-5.4-mini",
|
||||
}
|
||||
|
||||
if err := store.AddFullMessage(ctx, "model-name", msg); err != nil {
|
||||
t.Fatalf("AddFullMessage: %v", err)
|
||||
}
|
||||
|
||||
history, err := store.GetHistory(ctx, "model-name")
|
||||
if err != nil {
|
||||
t.Fatalf("GetHistory: %v", err)
|
||||
}
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(history))
|
||||
}
|
||||
if history[0].ModelName != "gpt-5.4-mini" {
|
||||
t.Fatalf("ModelName = %q, want %q", history[0].ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFullMessage_ToolCallID(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -17,6 +17,7 @@ type FallbackChain struct {
|
||||
type FallbackCandidate struct {
|
||||
Provider string
|
||||
Model string
|
||||
DisplayName string // optional configured alias/raw model label for persistence/UI
|
||||
RPM int // requests per minute; 0 means unrestricted
|
||||
IdentityKey string // optional stable config identity for cooldown/rate limiting
|
||||
}
|
||||
@@ -32,10 +33,11 @@ func (c FallbackCandidate) StableKey() string {
|
||||
|
||||
// FallbackResult contains the successful response and metadata about all attempts.
|
||||
type FallbackResult struct {
|
||||
Response *LLMResponse
|
||||
Provider string
|
||||
Model string
|
||||
Attempts []FallbackAttempt
|
||||
Response *LLMResponse
|
||||
Provider string
|
||||
Model string
|
||||
IdentityKey string
|
||||
Attempts []FallbackAttempt
|
||||
}
|
||||
|
||||
// FallbackAttempt records one attempt in the fallback chain.
|
||||
@@ -85,8 +87,9 @@ func ResolveCandidatesWithLookup(
|
||||
}
|
||||
seen[key] = true
|
||||
candidates = append(candidates, FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
DisplayName: candidateRaw,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -187,6 +190,7 @@ func (fc *FallbackChain) Execute(
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
result.IdentityKey = candidate.StableKey()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -305,6 +309,7 @@ func (fc *FallbackChain) ExecuteImage(
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
result.IdentityKey = candidate.StableKey()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ type Attachment struct {
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
Attachments []Attachment `json:"attachments,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
|
||||
@@ -46,6 +46,7 @@ func runSchema(db *sql.DB) error {
|
||||
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
model_name TEXT NOT NULL DEFAULT '',
|
||||
reasoning_content TEXT NOT NULL DEFAULT '',
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
@@ -162,6 +163,9 @@ func runSchema(db *sql.DB) error {
|
||||
if err := ensureMessagesReasoningContentColumn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureMessagesModelNameColumn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -180,6 +184,21 @@ func ensureMessagesReasoningContentColumn(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureMessagesModelNameColumn(db *sql.DB) error {
|
||||
hasColumn, err := tableHasColumn(db, "messages", "model_name")
|
||||
if err != nil {
|
||||
return fmt.Errorf("check messages.model_name: %w", err)
|
||||
}
|
||||
if hasColumn {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`ALTER TABLE messages ADD COLUMN model_name TEXT NOT NULL DEFAULT ''`); err != nil {
|
||||
return fmt.Errorf("add messages.model_name: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func tableHasColumn(db *sql.DB, tableName, columnName string) (bool, error) {
|
||||
rows, err := db.Query(fmt.Sprintf(`PRAGMA table_info(%s)`, tableName))
|
||||
if err != nil {
|
||||
|
||||
@@ -138,6 +138,37 @@ func TestRunSchemaAddsMessagesReasoningContentColumn(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSchemaAddsMessagesModelNameColumn(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.Exec(`CREATE TABLE messages (
|
||||
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
conversation_id INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
reasoning_content TEXT NOT NULL DEFAULT '',
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`)
|
||||
if err != nil {
|
||||
t.Fatalf("create legacy messages table: %v", err)
|
||||
}
|
||||
|
||||
err = runSchema(db)
|
||||
if err != nil {
|
||||
t.Fatalf("runSchema: %v", err)
|
||||
}
|
||||
|
||||
var count int
|
||||
err = db.QueryRow(`SELECT count(*) FROM pragma_table_info('messages') WHERE name = 'model_name'`).Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("query pragma_table_info: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("model_name column count = %d, want 1", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationConversationUnique(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
|
||||
@@ -258,6 +258,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa
|
||||
conv.ConversationID,
|
||||
msg.Role,
|
||||
msg.Parts,
|
||||
msg.ModelName,
|
||||
msg.ReasoningContent,
|
||||
msg.TokenCount,
|
||||
)
|
||||
@@ -267,6 +268,7 @@ func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Messa
|
||||
conv.ConversationID,
|
||||
msg.Role,
|
||||
msg.Content,
|
||||
msg.ModelName,
|
||||
msg.ReasoningContent,
|
||||
msg.TokenCount,
|
||||
)
|
||||
@@ -431,6 +433,31 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
|
||||
return fmt.Errorf("bootstrap: get messages: %w", err)
|
||||
}
|
||||
|
||||
// Migration repair path: old SeaHorse rows may be missing reasoning_content
|
||||
// even though the canonical JSONL history already has it. Backfill those
|
||||
// rows in place so we do not treat this as edited history and leave stale
|
||||
// summaries/context behind after a partial raw-message rebuild.
|
||||
repairedReasoning, err := e.repairBootstrapReasoningContent(ctx, dbMsgs, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: repair reasoning_content: %w", err)
|
||||
}
|
||||
repairedModelName, err := e.repairBootstrapModelName(ctx, dbMsgs, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: repair model_name: %w", err)
|
||||
}
|
||||
if (repairedReasoning || repairedModelName) && len(dbMsgs) == len(messages) {
|
||||
matched := true
|
||||
for i := range messages {
|
||||
if !messageMatches(dbMsgs[i], messages[i]) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path: DB has same count and exact match → no-op
|
||||
if len(dbMsgs) == len(messages) {
|
||||
matched := true
|
||||
@@ -445,16 +472,6 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
|
||||
}
|
||||
}
|
||||
|
||||
// Migration repair path: old SeaHorse rows may be missing reasoning_content
|
||||
// even though the canonical JSONL history already has it. Backfill those
|
||||
// rows in place so we do not treat this as edited history and leave stale
|
||||
// summaries/context behind after a partial raw-message rebuild.
|
||||
if repaired, err := e.repairBootstrapReasoningContent(ctx, dbMsgs, messages); err != nil {
|
||||
return fmt.Errorf("bootstrap: repair reasoning_content: %w", err)
|
||||
} else if repaired && len(dbMsgs) == len(messages) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find longest matching prefix from the start
|
||||
anchor := -1
|
||||
compareLen := min(len(dbMsgs), len(messages))
|
||||
@@ -465,14 +482,16 @@ func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Me
|
||||
} else {
|
||||
// Mismatch detected - log details and rebuild
|
||||
logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"index": i,
|
||||
"db_role": dbMsgs[i].Role,
|
||||
"db_content": truncate(dbMsgs[i].Content, 50),
|
||||
"db_parts": len(dbMsgs[i].Parts),
|
||||
"msg_role": messages[i].Role,
|
||||
"msg_content": truncate(messages[i].Content, 50),
|
||||
"msg_parts": len(messages[i].Parts),
|
||||
"conv_id": conv.ConversationID,
|
||||
"index": i,
|
||||
"db_role": dbMsgs[i].Role,
|
||||
"db_content": truncate(dbMsgs[i].Content, 50),
|
||||
"db_parts": len(dbMsgs[i].Parts),
|
||||
"db_model_name": dbMsgs[i].ModelName,
|
||||
"msg_role": messages[i].Role,
|
||||
"msg_content": truncate(messages[i].Content, 50),
|
||||
"msg_parts": len(messages[i].Parts),
|
||||
"msg_model_name": messages[i].ModelName,
|
||||
})
|
||||
break
|
||||
}
|
||||
@@ -559,7 +578,7 @@ func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, me
|
||||
}
|
||||
|
||||
for i := range overlap {
|
||||
if !messageMatchesIgnoringReasoning(dbMsgs[i], messages[i]) {
|
||||
if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) {
|
||||
return false, nil
|
||||
}
|
||||
if dbMsgs[i].ReasoningContent == messages[i].ReasoningContent {
|
||||
@@ -596,6 +615,57 @@ func (e *Engine) repairBootstrapReasoningContent(ctx context.Context, dbMsgs, me
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (e *Engine) repairBootstrapModelName(ctx context.Context, dbMsgs, messages []Message) (bool, error) {
|
||||
if len(dbMsgs) == 0 || len(messages) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
overlap := min(len(messages), len(dbMsgs))
|
||||
|
||||
var updates []struct {
|
||||
index int
|
||||
messageID int64
|
||||
modelName string
|
||||
}
|
||||
|
||||
for i := range overlap {
|
||||
if !messageMatchesIgnoringReasoningAndModelName(dbMsgs[i], messages[i]) {
|
||||
return false, nil
|
||||
}
|
||||
if dbMsgs[i].ModelName == messages[i].ModelName {
|
||||
continue
|
||||
}
|
||||
if messages[i].ModelName == "" {
|
||||
return false, nil
|
||||
}
|
||||
updates = append(updates, struct {
|
||||
index int
|
||||
messageID int64
|
||||
modelName string
|
||||
}{
|
||||
index: i,
|
||||
messageID: dbMsgs[i].ID,
|
||||
modelName: messages[i].ModelName,
|
||||
})
|
||||
}
|
||||
|
||||
if len(updates) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, update := range updates {
|
||||
if err := e.store.UpdateMessageModelName(ctx, update.messageID, update.modelName); err != nil {
|
||||
return false, err
|
||||
}
|
||||
dbMsgs[update.index].ModelName = update.modelName
|
||||
}
|
||||
|
||||
logger.InfoCF("seahorse", "bootstrap: repaired missing model_name", map[string]any{
|
||||
"messages": len(updates),
|
||||
})
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// truncate shortens a string for logging.
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
@@ -610,13 +680,20 @@ func truncate(s string, maxLen int) string {
|
||||
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
|
||||
// because structured messages are matched by their parts payload.
|
||||
func messageMatches(a, b Message) bool {
|
||||
if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent {
|
||||
if a.Role != b.Role || a.ReasoningContent != b.ReasoningContent || a.ModelName != b.ModelName {
|
||||
return false
|
||||
}
|
||||
return messageMatchesIgnoringReasoning(a, b)
|
||||
}
|
||||
|
||||
func messageMatchesIgnoringReasoning(a, b Message) bool {
|
||||
if a.ModelName != b.ModelName {
|
||||
return false
|
||||
}
|
||||
return messageMatchesIgnoringReasoningAndModelName(a, b)
|
||||
}
|
||||
|
||||
func messageMatchesIgnoringReasoningAndModelName(a, b Message) bool {
|
||||
if a.Role != b.Role {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -25,6 +25,43 @@ func newTestEngine(t *testing.T) *Engine {
|
||||
}
|
||||
}
|
||||
|
||||
func prepareBootstrapRepairConversation(
|
||||
t *testing.T,
|
||||
eng *Engine,
|
||||
ctx context.Context,
|
||||
sessionKey string,
|
||||
) (*Conversation, []Message) {
|
||||
t.Helper()
|
||||
|
||||
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("GetOrCreateConversation: %v", err)
|
||||
}
|
||||
|
||||
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage user: %v", err)
|
||||
}
|
||||
|
||||
assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage assistant: %v", err)
|
||||
}
|
||||
|
||||
if err := eng.store.AppendContextMessages(
|
||||
ctx,
|
||||
conv.ConversationID,
|
||||
[]int64{userMsg.ID, assistantMsg.ID},
|
||||
); err != nil {
|
||||
t.Fatalf("AppendContextMessages: %v", err)
|
||||
}
|
||||
|
||||
return conv, []Message{
|
||||
{Role: "user", Content: "hello", TokenCount: 3},
|
||||
{Role: "assistant", Content: "world", TokenCount: 3},
|
||||
}
|
||||
}
|
||||
|
||||
// --- compileSessionPattern ---
|
||||
|
||||
func TestCompileSessionPattern(t *testing.T) {
|
||||
@@ -328,6 +365,7 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) {
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "world",
|
||||
ModelName: "gpt-5.4-mini",
|
||||
ReasoningContent: "let me think this through",
|
||||
TokenCount: 4,
|
||||
},
|
||||
@@ -353,6 +391,9 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) {
|
||||
"let me think this through",
|
||||
)
|
||||
}
|
||||
if stored[0].ModelName != "gpt-5.4-mini" {
|
||||
t.Errorf("stored[0].ModelName = %q, want %q", stored[0].ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
|
||||
result, err := eng.Assemble(ctx, "agent:reasoning", AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
@@ -368,6 +409,140 @@ func TestEngineIngestPreservesReasoningContent(t *testing.T) {
|
||||
"let me think this through",
|
||||
)
|
||||
}
|
||||
if result.Messages[0].ModelName != "gpt-5.4-mini" {
|
||||
t.Errorf("assembled model_name = %q, want %q", result.Messages[0].ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapRepairsMissingModelName(t *testing.T) {
|
||||
eng := newTestEngine(t)
|
||||
ctx := context.Background()
|
||||
sessionKey := "agent:repair-model-name"
|
||||
conv, msgs := prepareBootstrapRepairConversation(t, eng, ctx, sessionKey)
|
||||
msgs[1].ModelName = "gpt-5.4"
|
||||
|
||||
err := eng.Bootstrap(ctx, sessionKey, msgs)
|
||||
if err != nil {
|
||||
t.Fatalf("Bootstrap: %v", err)
|
||||
}
|
||||
|
||||
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("GetMessages: %v", err)
|
||||
}
|
||||
if len(stored) != 2 {
|
||||
t.Fatalf("stored messages = %d, want 2", len(stored))
|
||||
}
|
||||
if stored[1].ModelName != "gpt-5.4" {
|
||||
t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapRepairsReasoningContentAndModelNameTogether(t *testing.T) {
|
||||
eng := newTestEngine(t)
|
||||
ctx := context.Background()
|
||||
sessionKey := "agent:repair-both-fields"
|
||||
|
||||
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("GetOrCreateConversation: %v", err)
|
||||
}
|
||||
|
||||
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage user: %v", err)
|
||||
}
|
||||
|
||||
assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage assistant: %v", err)
|
||||
}
|
||||
|
||||
err = eng.store.AppendContextMessages(ctx, conv.ConversationID, []int64{userMsg.ID, assistantMsg.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("AppendContextMessages: %v", err)
|
||||
}
|
||||
|
||||
err = eng.Bootstrap(ctx, sessionKey, []Message{
|
||||
{Role: "user", Content: "hello", TokenCount: 3},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "world",
|
||||
ModelName: "gpt-5.4",
|
||||
ReasoningContent: "let me think this through",
|
||||
TokenCount: 3,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Bootstrap: %v", err)
|
||||
}
|
||||
|
||||
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("GetMessages: %v", err)
|
||||
}
|
||||
if len(stored) != 2 {
|
||||
t.Fatalf("stored messages = %d, want 2", len(stored))
|
||||
}
|
||||
if stored[1].ReasoningContent != "let me think this through" {
|
||||
t.Fatalf("stored[1].ReasoningContent = %q, want %q", stored[1].ReasoningContent, "let me think this through")
|
||||
}
|
||||
if stored[1].ModelName != "gpt-5.4" {
|
||||
t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapRepairsIncorrectNonEmptyModelName(t *testing.T) {
|
||||
eng := newTestEngine(t)
|
||||
ctx := context.Background()
|
||||
sessionKey := "agent:repair-wrong-model-name"
|
||||
|
||||
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("GetOrCreateConversation: %v", err)
|
||||
}
|
||||
|
||||
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage user: %v", err)
|
||||
}
|
||||
|
||||
assistantMsg, err := eng.store.AddMessageWithReasoning(
|
||||
ctx,
|
||||
conv.ConversationID,
|
||||
"assistant",
|
||||
"world",
|
||||
"wrong-model",
|
||||
"",
|
||||
3,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessageWithReasoning assistant: %v", err)
|
||||
}
|
||||
|
||||
err = eng.store.AppendContextMessages(ctx, conv.ConversationID, []int64{userMsg.ID, assistantMsg.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("AppendContextMessages: %v", err)
|
||||
}
|
||||
|
||||
err = eng.Bootstrap(ctx, sessionKey, []Message{
|
||||
{Role: "user", Content: "hello", TokenCount: 3},
|
||||
{Role: "assistant", Content: "world", ModelName: "gpt-5.4", TokenCount: 3},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Bootstrap: %v", err)
|
||||
}
|
||||
|
||||
stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("GetMessages: %v", err)
|
||||
}
|
||||
if len(stored) != 2 {
|
||||
t.Fatalf("stored messages = %d, want 2", len(stored))
|
||||
}
|
||||
if stored[1].ModelName != "gpt-5.4" {
|
||||
t.Fatalf("stored[1].ModelName = %q, want %q", stored[1].ModelName, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineIngestWithPartsPreservesReasoningContent(t *testing.T) {
|
||||
@@ -620,35 +795,10 @@ func TestBootstrapRepairsMissingReasoningContent(t *testing.T) {
|
||||
eng := newTestEngine(t)
|
||||
ctx := context.Background()
|
||||
sessionKey := "agent:repair-reasoning"
|
||||
conv, msgs := prepareBootstrapRepairConversation(t, eng, ctx, sessionKey)
|
||||
msgs[1].ReasoningContent = "let me think this through"
|
||||
|
||||
conv, err := eng.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("GetOrCreateConversation: %v", err)
|
||||
}
|
||||
|
||||
userMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "user", "hello", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage user: %v", err)
|
||||
}
|
||||
|
||||
assistantMsg, err := eng.store.AddMessage(ctx, conv.ConversationID, "assistant", "world", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage assistant: %v", err)
|
||||
}
|
||||
|
||||
err = eng.store.AppendContextMessages(
|
||||
ctx,
|
||||
conv.ConversationID,
|
||||
[]int64{userMsg.ID, assistantMsg.ID},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AppendContextMessages: %v", err)
|
||||
}
|
||||
|
||||
err = eng.Bootstrap(ctx, sessionKey, []Message{
|
||||
{Role: "user", Content: "hello", TokenCount: 3},
|
||||
{Role: "assistant", Content: "world", ReasoningContent: "let me think this through", TokenCount: 3},
|
||||
})
|
||||
err := eng.Bootstrap(ctx, sessionKey, msgs)
|
||||
if err != nil {
|
||||
t.Fatalf("Bootstrap: %v", err)
|
||||
}
|
||||
|
||||
+53
-14
@@ -162,19 +162,25 @@ func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Tim
|
||||
|
||||
// AddMessage appends a message to a conversation.
|
||||
func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) {
|
||||
return s.AddMessageWithReasoning(ctx, convID, role, content, "", tokenCount)
|
||||
return s.AddMessageWithReasoning(ctx, convID, role, content, "", "", tokenCount)
|
||||
}
|
||||
|
||||
// AddMessageWithReasoning appends a message with reasoning content to a conversation.
|
||||
func (s *Store) AddMessageWithReasoning(
|
||||
ctx context.Context,
|
||||
convID int64,
|
||||
role, content, reasoningContent string,
|
||||
role, content, modelName, reasoningContent string,
|
||||
tokenCount int,
|
||||
) (*Message, error) {
|
||||
result, err := s.db.ExecContext(ctx,
|
||||
"INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)",
|
||||
convID, role, content, reasoningContent, tokenCount,
|
||||
result, err := s.db.ExecContext(
|
||||
ctx,
|
||||
"INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
convID,
|
||||
role,
|
||||
content,
|
||||
modelName,
|
||||
reasoningContent,
|
||||
tokenCount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add message: %w", err)
|
||||
@@ -185,6 +191,7 @@ func (s *Store) AddMessageWithReasoning(
|
||||
ConversationID: convID,
|
||||
Role: role,
|
||||
Content: content,
|
||||
ModelName: modelName,
|
||||
ReasoningContent: reasoningContent,
|
||||
TokenCount: tokenCount,
|
||||
}, nil
|
||||
@@ -224,7 +231,7 @@ func (s *Store) AddMessageWithParts(
|
||||
parts []MessagePart,
|
||||
tokenCount int,
|
||||
) (*Message, error) {
|
||||
return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", tokenCount)
|
||||
return s.AddMessageWithPartsAndReasoning(ctx, convID, role, parts, "", "", tokenCount)
|
||||
}
|
||||
|
||||
// AddMessageWithPartsAndReasoning adds a message with structured parts and reasoning content.
|
||||
@@ -233,6 +240,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
|
||||
convID int64,
|
||||
role string,
|
||||
parts []MessagePart,
|
||||
modelName string,
|
||||
reasoningContent string,
|
||||
tokenCount int,
|
||||
) (*Message, error) {
|
||||
@@ -245,9 +253,15 @@ func (s *Store) AddMessageWithPartsAndReasoning(
|
||||
// Derive readable content from Parts for FTS5 indexing and summary formatting
|
||||
readableContent := partsToReadableContent(parts)
|
||||
|
||||
result, err := tx.ExecContext(ctx,
|
||||
"INSERT INTO messages (conversation_id, role, content, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?)",
|
||||
convID, role, readableContent, reasoningContent, tokenCount,
|
||||
result, err := tx.ExecContext(
|
||||
ctx,
|
||||
"INSERT INTO messages (conversation_id, role, content, model_name, reasoning_content, token_count) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
convID,
|
||||
role,
|
||||
readableContent,
|
||||
modelName,
|
||||
reasoningContent,
|
||||
tokenCount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add message: %w", err)
|
||||
@@ -282,6 +296,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
|
||||
ID: msgID,
|
||||
ConversationID: convID,
|
||||
Role: role,
|
||||
ModelName: modelName,
|
||||
ReasoningContent: reasoningContent,
|
||||
TokenCount: tokenCount,
|
||||
Parts: make([]MessagePart, len(parts)),
|
||||
@@ -295,7 +310,7 @@ func (s *Store) AddMessageWithPartsAndReasoning(
|
||||
|
||||
// GetMessages retrieves messages for a conversation.
|
||||
func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, beforeID int64) ([]Message, error) {
|
||||
query := "SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE conversation_id = ?"
|
||||
query := "SELECT message_id, conversation_id, role, content, model_name, reasoning_content, token_count, created_at FROM messages WHERE conversation_id = ?"
|
||||
args := []any{convID}
|
||||
if beforeID > 0 {
|
||||
query += " AND message_id < ?"
|
||||
@@ -322,6 +337,7 @@ func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, before
|
||||
&msg.ConversationID,
|
||||
&msg.Role,
|
||||
&msg.Content,
|
||||
&msg.ModelName,
|
||||
&msg.ReasoningContent,
|
||||
&msg.TokenCount,
|
||||
&createdAt,
|
||||
@@ -362,9 +378,9 @@ func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message,
|
||||
var createdAt string
|
||||
err := s.db.QueryRowContext(
|
||||
ctx,
|
||||
"SELECT message_id, conversation_id, role, content, reasoning_content, token_count, created_at FROM messages WHERE message_id = ?",
|
||||
"SELECT message_id, conversation_id, role, content, model_name, reasoning_content, token_count, created_at FROM messages WHERE message_id = ?",
|
||||
messageID,
|
||||
).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.ReasoningContent, &msg.TokenCount, &createdAt)
|
||||
).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.ModelName, &msg.ReasoningContent, &msg.TokenCount, &createdAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("message %d not found", messageID)
|
||||
}
|
||||
@@ -398,6 +414,27 @@ func (s *Store) UpdateMessageReasoningContent(ctx context.Context, messageID int
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateMessageModelName(ctx context.Context, messageID int64, modelName string) error {
|
||||
result, err := s.db.ExecContext(
|
||||
ctx,
|
||||
"UPDATE messages SET model_name = ? WHERE message_id = ?",
|
||||
modelName,
|
||||
messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update message model_name: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("update message model_name rows affected: %w", err)
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("message %d not found", messageID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type
|
||||
@@ -581,8 +618,9 @@ func (s *Store) LinkSummaryToMessages(ctx context.Context, summaryID string, mes
|
||||
|
||||
// GetSummarySourceMessages retrieves source messages for a summary.
|
||||
func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) ([]Message, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT m.message_id, m.conversation_id, m.role, m.content, m.reasoning_content, m.token_count, m.created_at
|
||||
rows, err := s.db.QueryContext(
|
||||
ctx,
|
||||
`SELECT m.message_id, m.conversation_id, m.role, m.content, m.model_name, m.reasoning_content, m.token_count, m.created_at
|
||||
FROM summary_messages sm
|
||||
JOIN messages m ON m.message_id = sm.message_id
|
||||
WHERE sm.summary_id = ?
|
||||
@@ -603,6 +641,7 @@ func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string)
|
||||
&msg.ConversationID,
|
||||
&msg.Role,
|
||||
&msg.Content,
|
||||
&msg.ModelName,
|
||||
&msg.ReasoningContent,
|
||||
&msg.TokenCount,
|
||||
&createdAt,
|
||||
|
||||
@@ -210,6 +210,7 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
|
||||
conv.ConversationID,
|
||||
"assistant",
|
||||
"hello world",
|
||||
"gpt-5.4-mini",
|
||||
"let me think",
|
||||
5,
|
||||
)
|
||||
@@ -219,6 +220,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
|
||||
if msg.ReasoningContent != "let me think" {
|
||||
t.Fatalf("ReasoningContent = %q, want %q", msg.ReasoningContent, "let me think")
|
||||
}
|
||||
if msg.ModelName != "gpt-5.4-mini" {
|
||||
t.Fatalf("ModelName = %q, want %q", msg.ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
|
||||
msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0)
|
||||
if err != nil {
|
||||
@@ -230,6 +234,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
|
||||
if msgs[0].ReasoningContent != "let me think" {
|
||||
t.Errorf("ReasoningContent = %q, want %q", msgs[0].ReasoningContent, "let me think")
|
||||
}
|
||||
if msgs[0].ModelName != "gpt-5.4-mini" {
|
||||
t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
|
||||
found, err := s.GetMessageByID(ctx, msg.ID)
|
||||
if err != nil {
|
||||
@@ -238,6 +245,9 @@ func TestStoreAddAndGetMessagesWithReasoningContent(t *testing.T) {
|
||||
if found.ReasoningContent != "let me think" {
|
||||
t.Errorf("GetMessageByID ReasoningContent = %q, want %q", found.ReasoningContent, "let me think")
|
||||
}
|
||||
if found.ModelName != "gpt-5.4-mini" {
|
||||
t.Errorf("GetMessageByID ModelName = %q, want %q", found.ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreAddMessageWithParts(t *testing.T) {
|
||||
@@ -288,6 +298,7 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) {
|
||||
conv.ConversationID,
|
||||
"assistant",
|
||||
parts,
|
||||
"gpt-5.4",
|
||||
"need to inspect the file first",
|
||||
10,
|
||||
)
|
||||
@@ -309,6 +320,9 @@ func TestStoreAddMessageWithPartsAndReasoningContent(t *testing.T) {
|
||||
"need to inspect the file first",
|
||||
)
|
||||
}
|
||||
if msgs[0].ModelName != "gpt-5.4" {
|
||||
t.Errorf("ModelName = %q, want %q", msgs[0].ModelName, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreGetMessageCount(t *testing.T) {
|
||||
|
||||
@@ -22,6 +22,7 @@ type Message struct {
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ModelName string `json:"modelName,omitempty"`
|
||||
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
@@ -135,6 +136,7 @@ func EstimateMessageTokens(msg Message) int {
|
||||
pm := providers.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ModelName: msg.ModelName,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +66,25 @@ func TestJSONLBackend_AddFullMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_AddFullMessage_PreservesModelName(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "done",
|
||||
ModelName: "gpt-5.4-mini",
|
||||
}
|
||||
b.AddFullMessage("s1", msg)
|
||||
|
||||
history := b.GetHistory("s1")
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("got %d, want 1", len(history))
|
||||
}
|
||||
if history[0].ModelName != "gpt-5.4-mini" {
|
||||
t.Fatalf("ModelName = %q, want %q", history[0].ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_Summary(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user