diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 6a188416d..899c233cb 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -2413,6 +2413,47 @@ turnLoop: if toolResult == nil { toolResult = tools.ErrorResult("hook returned nil tool result") } + if len(toolResult.Media) > 0 && toolResult.ResponseHandled { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { + part := bus.MediaPart{Ref: ref} + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) + } + outboundMedia := bus.OutboundMediaMessage{ + Channel: ts.channel, + ChatID: ts.chatID, + Parts: parts, + } + if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) { + if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { + logger.WarnCF("agent", "Failed to deliver handled tool media", + map[string]any{ + "agent_id": ts.agent.ID, + "tool": toolName, + "channel": ts.channel, + "chat_id": ts.chatID, + "error": err.Error(), + }) + toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err) + } + } else if al.bus != nil { + al.bus.PublishOutboundMedia(ctx, outboundMedia) + // Queuing media is only best-effort; it has not been delivered yet. + toolResult.ResponseHandled = false + } + } + + if len(toolResult.Media) > 0 && !toolResult.ResponseHandled { + toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media) + } + if !toolResult.ResponseHandled { allResponsesHandled = false } @@ -2430,29 +2471,6 @@ turnLoop: }) } - if len(toolResult.Media) > 0 && toolResult.ResponseHandled { - parts := make([]bus.MediaPart, 0, len(toolResult.Media)) - for _, ref := range toolResult.Media { - part := bus.MediaPart{Ref: ref} - if al.mediaStore != nil { - if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { - part.Filename = meta.Filename - part.ContentType = meta.ContentType - part.Type = inferMediaType(meta.Filename, meta.ContentType) - } - } - parts = append(parts, part) - } - al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Parts: parts, - }) - } - - if len(toolResult.Media) > 0 && !toolResult.ResponseHandled { - toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media) - } contentForLLM := toolResult.ContentForLLM() toolResultMsg := providers.Message{ @@ -2543,6 +2561,29 @@ turnLoop: } if allResponsesHandled { + if len(pendingMessages) > 0 { + logger.InfoCF("agent", "Pending steering exists after handled tool delivery; continuing turn before finalizing", + map[string]any{ + "agent_id": ts.agent.ID, + "steering_count": len(pendingMessages), + "session_key": ts.sessionKey, + }) + finalContent = "" + goto turnLoop + } + + if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { + logger.InfoCF("agent", "Steering arrived after handled tool delivery; continuing turn before finalizing", + map[string]any{ + "agent_id": ts.agent.ID, + "steering_count": len(steerMsgs), + "session_key": ts.sessionKey, + }) + pendingMessages = append(pendingMessages, steerMsgs...) + finalContent = "" + goto turnLoop + } + summaryMsg := providers.Message{ Role: "assistant", Content: handledToolResponseSummary, diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index ffb87d7dd..2bf544595 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -33,6 +33,41 @@ func (f *fakeChannel) IsAllowed(string) bool { func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } func (f *fakeChannel) ReasoningChannelID() string { return f.id } +type fakeMediaChannel struct { + fakeChannel + sentMedia []bus.OutboundMediaMessage +} + +func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + f.sentMedia = append(f.sentMedia, msg) + return nil +} + +func newStartedTestChannelManager( + t *testing.T, + msgBus *bus.MessageBus, + store media.MediaStore, + name string, + ch channels.Channel, +) *channels.Manager { + t.Helper() + + cm, err := channels.NewManager(&config.Config{}, msgBus, store) + if err != nil { + t.Fatalf("NewManager() error = %v", err) + } + cm.RegisterChannel(name, ch) + if err := cm.StartAll(context.Background()); err != nil { + t.Fatalf("StartAll() error = %v", err) + } + t.Cleanup(func() { + if err := cm.StopAll(context.Background()); err != nil { + t.Fatalf("StopAll() error = %v", err) + } + }) + return cm +} + type recordingProvider struct { lastMessages []providers.Message } @@ -554,6 +589,8 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. store := media.NewFileMediaStore() al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) imagePath := filepath.Join(tmpDir, "screen.png") if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { @@ -587,16 +624,20 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. t.Fatal("expected tools to be available on the first LLM call") } + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } + if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" { + t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0]) + } + if len(telegramChannel.sentMedia[0].Parts) != 1 { + t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts)) + } + select { - case mediaMsg := <-msgBus.OutboundMediaChan(): - if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { - t.Fatalf("unexpected outbound media target: %+v", mediaMsg) - } - if len(mediaMsg.Parts) != 1 { - t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) - } + case extra := <-msgBus.OutboundMediaChan(): + t.Fatalf("expected handled media to bypass async queue, got %+v", extra) default: - t.Fatal("expected outbound media message to be published") } defaultAgent := al.GetRegistry().GetDefaultAgent() @@ -623,6 +664,59 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. } } +func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &handledMediaWithSteeringProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) + + imagePath := filepath.Join(tmpDir, "screen-steering.png") + if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil { + t.Fatalf("WriteFile(imagePath) error = %v", err) + } + + al.RegisterTool(&handledMediaWithSteeringTool{ + store: store, + path: imagePath, + loop: al, + }) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "Handled the queued steering message." { + t.Fatalf("response = %q, want queued steering response", response) + } + if provider.calls != 2 { + t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls) + } + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } +} + func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { tmpDir := t.TempDir() cfg := config.DefaultConfig() @@ -637,6 +731,8 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { store := media.NewFileMediaStore() al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) mediaDir := media.TempDir() if err := os.MkdirAll(mediaDir, 0o700); err != nil { @@ -668,21 +764,19 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls) } - select { - case mediaMsg := <-msgBus.OutboundMediaChan(): - if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" { - t.Fatalf("unexpected outbound media target: %+v", mediaMsg) - } - if len(mediaMsg.Parts) != 1 { - t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts)) - } - default: - t.Fatal("expected outbound media from send_file") + if len(telegramChannel.sentMedia) != 1 { + t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia)) + } + if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" { + t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0]) + } + if len(telegramChannel.sentMedia[0].Parts) != 1 { + t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts)) } select { case extra := <-msgBus.OutboundMediaChan(): - t.Fatalf("expected exactly one outbound media delivery, got extra %+v", extra) + t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra) default: } } @@ -975,6 +1069,77 @@ func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *to return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() } +type handledMediaWithSteeringProvider struct { + calls int +} + +func (m *handledMediaWithSteeringProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Taking the screenshot now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_handled_media_steering", + Type: "function", + Name: "handled_media_with_steering_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + + for _, msg := range messages { + if msg.Role == "user" && msg.Content == "what about this instead?" { + return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil + } + } + + return nil, fmt.Errorf("provider did not receive queued steering message") +} + +func (m *handledMediaWithSteeringProvider) GetDefaultModel() string { + return "handled-media-with-steering-model" +} + +type handledMediaWithSteeringTool struct { + store media.MediaStore + path string + loop *AgentLoop +} + +func (m *handledMediaWithSteeringTool) Name() string { return "handled_media_with_steering_tool" } +func (m *handledMediaWithSteeringTool) Description() string { + return "Returns handled media and enqueues a steering message during execution" +} + +func (m *handledMediaWithSteeringTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + + ref, err := m.store.Store(m.path, media.MediaMeta{ + Filename: filepath.Base(m.path), + ContentType: "image/png", + Source: "test:handled_media_with_steering_tool", + }, "test:handled_media_with_steering") + if err != nil { + return tools.ErrorResult(err.Error()).WithError(err) + } + return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() +} + type mediaArtifactTool struct { store media.MediaStore path string