mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(chat,seahorse): persist and display model_name across history (#2897)
* feat(chat,seahorse): persist and display model_name across history * test(seahorse): fix lint regressions in repair coverage * fix(pico): preserve model_name in live updates * fix(pico): preserve model_name through live stream wrappers
This commit is contained in:
@@ -600,6 +600,12 @@ func (al *AgentLoop) runAgentLoop(
|
||||
Content: result.finalContent,
|
||||
ContextUsage: computeContextUsage(agent, opts.Dispatch.SessionKey),
|
||||
}
|
||||
if modelName := strings.TrimSpace(result.modelName); modelName != "" {
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = make(map[string]string, 1)
|
||||
}
|
||||
msg.Context.Raw["model_name"] = modelName
|
||||
}
|
||||
markFinalOutbound(&msg)
|
||||
al.bus.PublishOutbound(ctx, msg)
|
||||
}
|
||||
|
||||
+32
-11
@@ -102,7 +102,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, chatID, sessionKey string) {
|
||||
func (al *AgentLoop) publishPicoReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, chatID, sessionKey, modelName string,
|
||||
) {
|
||||
if reasoningContent == "" || chatID == "" {
|
||||
return
|
||||
}
|
||||
@@ -114,13 +117,16 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer pubCancel()
|
||||
|
||||
raw := map[string]string{metadataKeyMessageKind: messageKindThought}
|
||||
if trimmedModelName := strings.TrimSpace(modelName); trimmedModelName != "" {
|
||||
raw["model_name"] = trimmedModelName
|
||||
}
|
||||
|
||||
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "pico",
|
||||
ChatID: chatID,
|
||||
Raw: map[string]string{
|
||||
metadataKeyMessageKind: messageKindThought,
|
||||
},
|
||||
Raw: raw,
|
||||
},
|
||||
SessionKey: sessionKey,
|
||||
Content: reasoningContent,
|
||||
@@ -143,6 +149,7 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
|
||||
func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
ctx context.Context,
|
||||
ts *turnState,
|
||||
modelName string,
|
||||
reasoningContent string,
|
||||
content string,
|
||||
toolCalls []providers.ToolCall,
|
||||
@@ -155,7 +162,14 @@ func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err := al.bus.PublishOutbound(
|
||||
pubCtx,
|
||||
outboundMessageForTurnWithKind(ts, reasoningContent, messageKindThought),
|
||||
outboundMessageForTurnWithOptions(
|
||||
ts,
|
||||
reasoningContent,
|
||||
outboundTurnMessageOptions{
|
||||
kind: messageKindThought,
|
||||
modelName: modelName,
|
||||
},
|
||||
),
|
||||
)
|
||||
pubCancel()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
|
||||
@@ -182,7 +196,12 @@ func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
|
||||
if strings.TrimSpace(content) != "" && !duplicateToolCallContent {
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err := al.bus.PublishOutbound(pubCtx, outboundMessageForTurn(ts, content))
|
||||
err := al.bus.PublishOutbound(
|
||||
pubCtx,
|
||||
outboundMessageForTurnWithOptions(ts, content, outboundTurnMessageOptions{
|
||||
modelName: modelName,
|
||||
}),
|
||||
)
|
||||
pubCancel()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
|
||||
!errors.Is(err, context.Canceled) &&
|
||||
@@ -209,11 +228,13 @@ func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
return
|
||||
}
|
||||
|
||||
msg := outboundMessageForTurnWithKind(ts, "", messageKindToolCalls)
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = map[string]string{}
|
||||
}
|
||||
msg.Context.Raw[metadataKeyToolCalls] = string(rawToolCalls)
|
||||
msg := outboundMessageForTurnWithOptions(ts, "", outboundTurnMessageOptions{
|
||||
kind: messageKindToolCalls,
|
||||
modelName: modelName,
|
||||
raw: map[string]string{
|
||||
metadataKeyToolCalls: string(rawToolCalls),
|
||||
},
|
||||
})
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err = al.bus.PublishOutbound(pubCtx, msg)
|
||||
|
||||
@@ -312,7 +312,7 @@ func TestPublishPicoReasoningIncludesSessionKey(t *testing.T) {
|
||||
defer cleanup()
|
||||
_ = provider
|
||||
|
||||
al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1")
|
||||
al.publishPicoReasoning(context.Background(), "reasoning", "pico-chat", "session-1", "")
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
|
||||
@@ -96,15 +96,46 @@ func markFinalOutbound(msg *bus.OutboundMessage) {
|
||||
msg.Context.Raw[metadataKeyOutboundKind] = outboundKindFinal
|
||||
}
|
||||
|
||||
func outboundMessageForTurnWithKind(ts *turnState, content, kind string) bus.OutboundMessage {
|
||||
type outboundTurnMessageOptions struct {
|
||||
kind string
|
||||
modelName string
|
||||
raw map[string]string
|
||||
}
|
||||
|
||||
func outboundMessageForTurnWithOptions(
|
||||
ts *turnState,
|
||||
content string,
|
||||
opts outboundTurnMessageOptions,
|
||||
) bus.OutboundMessage {
|
||||
msg := outboundMessageForTurn(ts, content)
|
||||
if strings.TrimSpace(kind) == "" {
|
||||
trimmedKind := strings.TrimSpace(opts.kind)
|
||||
trimmedModelName := strings.TrimSpace(opts.modelName)
|
||||
rawCount := len(opts.raw)
|
||||
if trimmedKind != "" {
|
||||
rawCount++
|
||||
}
|
||||
if trimmedModelName != "" {
|
||||
rawCount++
|
||||
}
|
||||
if rawCount == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = make(map[string]string, 1)
|
||||
msg.Context.Raw = make(map[string]string, rawCount)
|
||||
}
|
||||
if trimmedKind != "" {
|
||||
msg.Context.Raw[metadataKeyMessageKind] = trimmedKind
|
||||
}
|
||||
if trimmedModelName != "" {
|
||||
msg.Context.Raw["model_name"] = trimmedModelName
|
||||
}
|
||||
for key, value := range opts.raw {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
msg.Context.Raw[key] = value
|
||||
}
|
||||
msg.Context.Raw[metadataKeyMessageKind] = kind
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -521,8 +552,9 @@ func hasMediaRefs(messages []providers.Message) bool {
|
||||
|
||||
func sideQuestionModelName(agent *AgentInstance, usedLight bool) string {
|
||||
if usedLight && len(agent.LightCandidates) > 0 {
|
||||
// Use the first light candidate's model
|
||||
return agent.LightCandidates[0].Model
|
||||
if name := resolvedCandidateModelName(agent.LightCandidates, ""); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
return agent.Model
|
||||
}
|
||||
@@ -538,6 +570,14 @@ func modelNameFromIdentityKey(identityKey string) string {
|
||||
return identityKey
|
||||
}
|
||||
|
||||
// modelAliasFromCandidateIdentityKey extracts the alias from an identity key
// of the form "model_name:<alias>". The alias is returned with surrounding
// whitespace removed; any key that does not start with that exact prefix
// yields "".
func modelAliasFromCandidateIdentityKey(identityKey string) string {
	const aliasPrefix = "model_name:"

	if strings.HasPrefix(identityKey, aliasPrefix) {
		return strings.TrimSpace(identityKey[len(aliasPrefix):])
	}
	return ""
}
|
||||
|
||||
func closeProviderIfStateful(provider providers.LLMProvider) {
|
||||
if stateful, ok := provider.(providers.StatefulProvider); ok {
|
||||
stateful.Close()
|
||||
|
||||
@@ -197,6 +197,7 @@ func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
|
||||
result := seahorse.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ModelName: msg.ModelName,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
TokenCount: tokenizer.EstimateMessageTokens(msg),
|
||||
}
|
||||
@@ -243,6 +244,7 @@ func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes
|
||||
pm := protocoltypes.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ModelName: msg.ModelName,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
}
|
||||
|
||||
|
||||
@@ -174,6 +174,7 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
|
||||
msg := protocoltypes.Message{
|
||||
Role: "assistant",
|
||||
Content: "response text",
|
||||
ModelName: "gpt-5.4-mini",
|
||||
ReasoningContent: "I thought about this carefully",
|
||||
}
|
||||
|
||||
@@ -181,6 +182,9 @@ func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
|
||||
if result.ReasoningContent != "I thought about this carefully" {
|
||||
t.Errorf("ReasoningContent = %q, want 'I thought about this carefully'", result.ReasoningContent)
|
||||
}
|
||||
if result.ModelName != "gpt-5.4-mini" {
|
||||
t.Errorf("ModelName = %q, want %q", result.ModelName, "gpt-5.4-mini")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
|
||||
@@ -189,6 +193,7 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "response",
|
||||
ModelName: "gpt-5.4",
|
||||
ReasoningContent: "thinking process",
|
||||
},
|
||||
},
|
||||
@@ -201,6 +206,9 @@ func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
|
||||
if messages[0].ReasoningContent != "thinking process" {
|
||||
t.Errorf("ReasoningContent = %q, want 'thinking process'", messages[0].ReasoningContent)
|
||||
}
|
||||
if messages[0].ModelName != "gpt-5.4" {
|
||||
t.Errorf("ModelName = %q, want %q", messages[0].ModelName, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeahorseToProviderMessages(t *testing.T) {
|
||||
|
||||
@@ -75,6 +75,7 @@ func candidateFromModelConfig(
|
||||
return providers.FallbackCandidate{
|
||||
Provider: protocol,
|
||||
Model: modelID,
|
||||
DisplayName: strings.TrimSpace(mc.ModelName),
|
||||
RPM: mc.RPM,
|
||||
IdentityKey: modelConfigIdentityKey(mc),
|
||||
}, true
|
||||
@@ -147,8 +148,9 @@ func resolveModelCandidate(
|
||||
}
|
||||
|
||||
return providers.FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
DisplayName: raw,
|
||||
}, true
|
||||
}
|
||||
|
||||
@@ -197,6 +199,18 @@ func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallbac
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedCandidateModelName(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 {
|
||||
if name := modelAliasFromCandidateIdentityKey(candidates[0].IdentityKey); strings.TrimSpace(name) != "" {
|
||||
return name
|
||||
}
|
||||
if displayName := strings.TrimSpace(candidates[0].DisplayName); displayName != "" {
|
||||
return displayName
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(fallback)
|
||||
}
|
||||
|
||||
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
|
||||
@@ -7,6 +7,55 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestModelNameFromIdentityKey_LegacyProviderModel(t *testing.T) {
|
||||
if got := modelNameFromIdentityKey("openai/gpt-5.4"); got != "gpt-5.4" {
|
||||
t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelNameFromIdentityKey_PreservesNonLegacyIdentity(t *testing.T) {
|
||||
if got := modelNameFromIdentityKey("model_name:primary"); got != "model_name:primary" {
|
||||
t.Fatalf("modelNameFromIdentityKey() = %q, want %q", got, "model_name:primary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelAliasFromCandidateIdentityKey(t *testing.T) {
|
||||
if got := modelAliasFromCandidateIdentityKey("model_name:primary"); got != "primary" {
|
||||
t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want %q", got, "primary")
|
||||
}
|
||||
if got := modelAliasFromCandidateIdentityKey("openai/gpt-5.4"); got != "" {
|
||||
t.Fatalf("modelAliasFromCandidateIdentityKey() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvedCandidateModelName_PrefersIdentityAlias(t *testing.T) {
|
||||
got := resolvedCandidateModelName([]providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
|
||||
}, "fallback-model")
|
||||
if got != "primary" {
|
||||
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvedCandidateModelName_DoesNotScanFallbackAliases(t *testing.T) {
|
||||
got := resolvedCandidateModelName([]providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4"},
|
||||
{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:fallback"},
|
||||
}, "primary-model")
|
||||
if got != "primary-model" {
|
||||
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "primary-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvedCandidateModelName_UsesCandidateDisplayName(t *testing.T) {
|
||||
got := resolvedCandidateModelName([]providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", DisplayName: "gpt-5.4-display"},
|
||||
}, "fallback-model")
|
||||
if got != "gpt-5.4-display" {
|
||||
t.Fatalf("resolvedCandidateModelName() = %q, want %q", got, "gpt-5.4-display")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveActiveModelConfig_PrefersCandidateIdentityKey(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
|
||||
@@ -180,7 +180,11 @@ toolLoop:
|
||||
toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen),
|
||||
)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions(
|
||||
ts,
|
||||
feedbackMsg,
|
||||
outboundTurnMessageOptions{kind: messageKindToolFeedback},
|
||||
))
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
@@ -467,7 +471,11 @@ toolLoop:
|
||||
toolFeedbackArgsPreview(toolArgs, toolFeedbackMaxLen),
|
||||
)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithKind(ts, feedbackMsg, messageKindToolFeedback))
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurnWithOptions(
|
||||
ts,
|
||||
feedbackMsg,
|
||||
outboundTurnMessageOptions{kind: messageKindToolFeedback},
|
||||
))
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ func (p *Pipeline) Finalize(
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
return turnResult{
|
||||
finalContent: finalContent,
|
||||
modelName: exec.llmModelName,
|
||||
status: turnStatus,
|
||||
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
|
||||
}, nil
|
||||
@@ -44,6 +45,7 @@ func (p *Pipeline) Finalize(
|
||||
finalMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: finalContent,
|
||||
ModelName: exec.llmModelName,
|
||||
ReasoningContent: responseReasoningContent(exec.response),
|
||||
}
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, finalMsg)
|
||||
@@ -80,24 +82,10 @@ func (p *Pipeline) Finalize(
|
||||
// so the final answer is still delivered outside normal SendResponse.
|
||||
if ((streamErr != nil && !isConfiguredStreamingVisibleError(streamErr)) || exec.streamingFallback) &&
|
||||
!ts.opts.SendResponse && ts.opts.AllowInterimPicoPublish && finalContent != "" {
|
||||
agentID, sessionKey, scope := outboundTurnMetadata(
|
||||
ts.agent.ID,
|
||||
ts.opts.Dispatch.SessionKey,
|
||||
ts.opts.Dispatch.SessionScope,
|
||||
)
|
||||
msg := bus.OutboundMessage{
|
||||
Context: outboundContextFromInbound(
|
||||
ts.opts.Dispatch.InboundContext,
|
||||
ts.opts.Dispatch.Channel(),
|
||||
ts.opts.Dispatch.ChatID(),
|
||||
ts.opts.Dispatch.ReplyToMessageID(),
|
||||
),
|
||||
AgentID: agentID,
|
||||
SessionKey: sessionKey,
|
||||
Scope: scope,
|
||||
Content: finalContent,
|
||||
ContextUsage: contextUsage,
|
||||
}
|
||||
msg := outboundMessageForTurnWithOptions(ts, finalContent, outboundTurnMessageOptions{
|
||||
modelName: exec.llmModelName,
|
||||
})
|
||||
msg.ContextUsage = contextUsage
|
||||
markFinalOutbound(&msg)
|
||||
_ = al.bus.PublishOutbound(turnCtx, msg)
|
||||
}
|
||||
@@ -112,6 +100,7 @@ func (p *Pipeline) Finalize(
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
return turnResult{
|
||||
finalContent: finalContent,
|
||||
modelName: exec.llmModelName,
|
||||
status: turnStatus,
|
||||
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
|
||||
}, nil
|
||||
|
||||
@@ -200,6 +200,16 @@ func (p *Pipeline) CallLLM(
|
||||
map[string]any{"agent_id": ts.agent.ID, "iteration": iteration},
|
||||
)
|
||||
}
|
||||
for _, candidate := range exec.activeCandidates {
|
||||
if candidate.StableKey() != fbResult.IdentityKey {
|
||||
continue
|
||||
}
|
||||
exec.llmModelName = resolvedCandidateModelName(
|
||||
[]providers.FallbackCandidate{candidate},
|
||||
exec.llmModelName,
|
||||
)
|
||||
break
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return exec.activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, exec.llmModel, exec.llmOpts)
|
||||
@@ -477,7 +487,7 @@ func (p *Pipeline) CallLLM(
|
||||
// Publish pico thoughts before the turn context is canceled at return time.
|
||||
// The async variant can race with turn teardown and intermittently drop the
|
||||
// thought message in CI even though the LLM produced reasoning content.
|
||||
al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey)
|
||||
al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID, ts.sessionKey, exec.llmModelName)
|
||||
}
|
||||
} else {
|
||||
go al.handleReasoning(
|
||||
@@ -564,6 +574,7 @@ func (p *Pipeline) CallLLM(
|
||||
assistantMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: exec.response.Content,
|
||||
ModelName: exec.llmModelName,
|
||||
ReasoningContent: reasoningContent,
|
||||
}
|
||||
for _, tc := range exec.normalizedToolCalls {
|
||||
@@ -607,6 +618,7 @@ func (p *Pipeline) CallLLM(
|
||||
al.publishPicoToolCallInterim(
|
||||
turnCtx,
|
||||
ts,
|
||||
exec.llmModelName,
|
||||
reasoningContent,
|
||||
exec.response.Content,
|
||||
assistantMsg.ToolCalls,
|
||||
|
||||
@@ -89,6 +89,11 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution
|
||||
if usedLight && ts.agent.LightProvider != nil {
|
||||
activeProvider = ts.agent.LightProvider
|
||||
}
|
||||
activeModelName := strings.TrimSpace(ts.agent.Model)
|
||||
if usedLight {
|
||||
activeModelName = strings.TrimSpace(sideQuestionModelName(ts.agent, true))
|
||||
}
|
||||
activeModelName = resolvedCandidateModelName(activeCandidates, activeModelName)
|
||||
|
||||
exec := newTurnExecution(
|
||||
ts.agent,
|
||||
@@ -106,6 +111,7 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution
|
||||
activeModel,
|
||||
p.Cfg.Agents.Defaults.Provider,
|
||||
)
|
||||
exec.llmModelName = activeModelName
|
||||
exec.activeProvider = activeProvider
|
||||
exec.usedLight = usedLight
|
||||
|
||||
|
||||
@@ -50,9 +50,10 @@ func (p *Pipeline) tryConfiguredStreamingLLM(
|
||||
}
|
||||
|
||||
publisher := &streamingChunkPublisher{
|
||||
streamer: streamer,
|
||||
channel: ts.channel,
|
||||
chatID: ts.chatID,
|
||||
streamer: streamer,
|
||||
channel: ts.channel,
|
||||
chatID: ts.chatID,
|
||||
modelName: exec.llmModelName,
|
||||
}
|
||||
|
||||
logger.DebugCF("agent", "configured streaming enabled", map[string]any{
|
||||
@@ -371,6 +372,7 @@ type streamingChunkPublisher struct {
|
||||
streamer bus.Streamer
|
||||
channel string
|
||||
chatID string
|
||||
modelName string
|
||||
published bool
|
||||
reasoningPublished bool
|
||||
err error
|
||||
@@ -380,6 +382,9 @@ func (p *streamingChunkPublisher) Update(ctx context.Context, accumulated string
|
||||
if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" {
|
||||
return
|
||||
}
|
||||
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName(p.modelName)
|
||||
}
|
||||
if err := p.streamer.Update(ctx, accumulated); err != nil {
|
||||
p.err = err
|
||||
logger.WarnCF("agent", "stream update failed", map[string]any{
|
||||
@@ -396,6 +401,9 @@ func (p *streamingChunkPublisher) UpdateReasoning(ctx context.Context, accumulat
|
||||
if p == nil || p.streamer == nil || strings.TrimSpace(accumulated) == "" {
|
||||
return
|
||||
}
|
||||
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName(p.modelName)
|
||||
}
|
||||
reasoningStreamer, ok := p.streamer.(bus.ReasoningStreamer)
|
||||
if !ok {
|
||||
return
|
||||
@@ -434,6 +442,9 @@ func (p *streamingChunkPublisher) Finalize(ctx context.Context, content string,
|
||||
if strings.TrimSpace(content) == "" && !p.published {
|
||||
return nil
|
||||
}
|
||||
if setter, ok := p.streamer.(interface{ SetModelName(modelName string) }); ok {
|
||||
setter.SetModelName(p.modelName)
|
||||
}
|
||||
var err error
|
||||
if streamer, ok := p.streamer.(bus.ContextUsageStreamer); ok {
|
||||
err = streamer.FinalizeWithContext(ctx, content, contextUsage)
|
||||
|
||||
@@ -570,6 +570,9 @@ func TestConfiguredStreamingFinalFlushFailureBeforeVisibleOutputPublishesFallbac
|
||||
if outbound.Content != "stream response" {
|
||||
t.Fatalf("fallback outbound content = %q, want stream response", outbound.Content)
|
||||
}
|
||||
if got := outbound.Context.Raw["model_name"]; got != "test-model" {
|
||||
t.Fatalf("fallback outbound model_name = %q, want %q", got, "test-model")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected fallback outbound after invisible final stream flush failure")
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
@@ -37,6 +38,39 @@ func (p *simpleConvProvider) GetDefaultModel() string {
|
||||
return "simple-model"
|
||||
}
|
||||
|
||||
type sequenceProvider struct {
|
||||
responses []*providers.LLMResponse
|
||||
errors []error
|
||||
callCount int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (p *sequenceProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
idx := p.callCount
|
||||
p.callCount++
|
||||
|
||||
if idx < len(p.errors) && p.errors[idx] != nil {
|
||||
return nil, p.errors[idx]
|
||||
}
|
||||
if idx < len(p.responses) && p.responses[idx] != nil {
|
||||
return p.responses[idx], nil
|
||||
}
|
||||
return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
func (p *sequenceProvider) GetDefaultModel() string {
|
||||
return "sequence-model"
|
||||
}
|
||||
|
||||
type nativeSearchCaptureProvider struct {
|
||||
lastOpts map[string]any
|
||||
}
|
||||
@@ -271,6 +305,152 @@ func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_SetupTurn_ModelNameDoesNotUseFallbackAliasBeforeFallback(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4"},
|
||||
{Provider: "anthropic", Model: "claude-sonnet", IdentityKey: "model_name:fallback-model"},
|
||||
}
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
if exec.llmModelName != "primary-model" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "primary-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_UsesSuccessfulFallbackIdentityAlias(t *testing.T) {
|
||||
provider := &sequenceProvider{
|
||||
errors: []error{
|
||||
errors.New("status: 429 - rate limit exceeded"),
|
||||
nil,
|
||||
},
|
||||
responses: []*providers.LLMResponse{
|
||||
nil,
|
||||
{Content: "fallback answer", FinishReason: "stop"},
|
||||
},
|
||||
}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:secondary"},
|
||||
}
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallLLM failed: %v", err)
|
||||
}
|
||||
if ctrl != ControlBreak {
|
||||
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
||||
}
|
||||
if exec.llmModelName != "secondary" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "secondary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_UsesSuccessfulFallbackDisplayNameWithoutAlias(t *testing.T) {
|
||||
provider := &sequenceProvider{
|
||||
errors: []error{
|
||||
errors.New("status: 429 - rate limit exceeded"),
|
||||
nil,
|
||||
},
|
||||
responses: []*providers.LLMResponse{
|
||||
nil,
|
||||
{Content: "fallback answer", FinishReason: "stop"},
|
||||
},
|
||||
}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
|
||||
{Provider: "anthropic", Model: "claude-sonnet", DisplayName: "anthropic/claude-sonnet"},
|
||||
}
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallLLM failed: %v", err)
|
||||
}
|
||||
if ctrl != ControlBreak {
|
||||
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
||||
}
|
||||
if exec.llmModelName != "anthropic/claude-sonnet" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "anthropic/claude-sonnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_SetupTurn_UsesLightCandidateDisplayName(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
agent.Model = "primary-model"
|
||||
agent.Candidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
|
||||
}
|
||||
agent.LightCandidates = []providers.FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:light-model", DisplayName: "light-model"},
|
||||
}
|
||||
agent.Router = routing.New(routing.RouterConfig{LightModel: "light-model", Threshold: 1})
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
opts := makeTestProcessOpts("test-session")
|
||||
opts.UserMessage = ""
|
||||
ts := newTurnState(agent, opts, turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
if !exec.usedLight {
|
||||
t.Fatal("expected light routing to be used")
|
||||
}
|
||||
if exec.llmModelName != "light-model" {
|
||||
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "light-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunTurn_FinalizeSaveErrorEmitsErrorTurnEnd(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
@@ -84,6 +84,7 @@ const (
|
||||
|
||||
// turnResult is the value Pipeline.Finalize returns for a completed turn.
type turnResult struct {
|
||||
// finalContent is the final content produced for the turn.
finalContent string
|
||||
// modelName is populated from the turn execution's llmModelName; the agent
// loop trims it and, when non-empty, publishes it as "model_name".
modelName string
|
||||
// status is the end status reported for the turn.
status TurnEndStatus
|
||||
// followUps is a copy of the turn state's follow-up inbound messages.
followUps []bus.InboundMessage
|
||||
}
|
||||
@@ -140,6 +141,7 @@ type turnExecution struct {
|
||||
callMessages []providers.Message
|
||||
providerToolDefs []providers.ToolDefinition
|
||||
llmModel string
|
||||
llmModelName string
|
||||
llmOpts map[string]any
|
||||
gracefulTerminal bool
|
||||
useNativeSearch bool
|
||||
|
||||
Reference in New Issue
Block a user