fix(pico): stream assistant text between tool calls

This commit is contained in:
lc6464
2026-04-09 22:32:35 +08:00
parent 5b596ed2f0
commit 2aeed8fb3a
2 changed files with 113 additions and 1 deletions
+15 -1
View File
@@ -1409,7 +1409,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
Media: msg.Media, Media: msg.Media,
DefaultResponse: defaultResponse, DefaultResponse: defaultResponse,
EnableSummary: true, EnableSummary: true,
SendResponse: msg.Channel == "pico", SendResponse: false,
} }
// context-dependent commands check their own Runtime fields and report // context-dependent commands check their own Runtime fields and report
@@ -2253,6 +2253,20 @@ turnLoop:
} }
logger.DebugCF("agent", "LLM response", llmResponseFields) logger.DebugCF("agent", "LLM response", llmResponseFields)
if al.bus != nil && ts.channel == "pico" {
liveContent := response.Content
if liveContent == "" && len(response.ToolCalls) == 0 && response.ReasoningContent != "" {
liveContent = response.ReasoningContent
}
if strings.TrimSpace(liveContent) != "" {
al.bus.PublishOutbound(turnCtx, bus.OutboundMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Content: liveContent,
})
}
}
if len(response.ToolCalls) == 0 || gracefulTerminal { if len(response.ToolCalls) == 0 || gracefulTerminal {
responseContent := response.Content responseContent := response.Content
if responseContent == "" && response.ReasoningContent != "" { if responseContent == "" && response.ReasoningContent != "" {
+98
View File
@@ -1069,6 +1069,40 @@ func (m *toolFeedbackProvider) GetDefaultModel() string {
return "heartbeat-tool-feedback-model" return "heartbeat-tool-feedback-model"
} }
type picoInterleavedContentProvider struct {
calls int
}
func (m *picoInterleavedContentProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
if m.calls == 1 {
return &providers.LLMResponse{
Content: "intermediate model text",
ToolCalls: []providers.ToolCall{{
ID: "call_tool_limit_test",
Type: "function",
Name: "tool_limit_test_tool",
Arguments: map[string]any{"value": "x"},
}},
}, nil
}
return &providers.LLMResponse{
Content: "final model text",
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *picoInterleavedContentProvider) GetDefaultModel() string {
return "pico-interleaved-content-model"
}
type toolLimitOnlyProvider struct{} type toolLimitOnlyProvider struct{}
func (m *toolLimitOnlyProvider) Chat( func (m *toolLimitOnlyProvider) Chat(
@@ -2732,6 +2766,70 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
} }
} }
func TestProcessMessage_PicoPublishesAssistantContentDuringToolCalls(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &picoInterleavedContentProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
t.Fatal("expected default agent")
}
agent.Tools.Register(&toolLimitTestTool{})
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "pico",
SenderID: "user-1",
ChatID: "session-1",
Content: "run with tools",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "final model text" {
t.Fatalf("processMessage() response = %q, want %q", response, "final model text")
}
outputs := make([]string, 0, 2)
deadline := time.After(2 * time.Second)
for len(outputs) < 2 {
select {
case outbound := <-msgBus.OutboundChan():
outputs = append(outputs, outbound.Content)
case <-deadline:
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
}
}
if outputs[0] != "intermediate model text" {
t.Fatalf("first outbound content = %q, want %q", outputs[0], "intermediate model text")
}
if outputs[1] != "final model text" {
t.Fatalf("second outbound content = %q, want %q", outputs[1], "final model text")
}
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content == "final model text" {
t.Fatalf("unexpected duplicate final pico output: %+v", outbound)
}
case <-time.After(200 * time.Millisecond):
}
}
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) { func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
store := media.NewFileMediaStore() store := media.NewFileMediaStore()
dir := t.TempDir() dir := t.TempDir()