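Fix handled tool media delivery and steering: send handled media synchronously through the channel manager, and process queued steering messages before finalizing the turn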

This commit is contained in:
afjcjsbx
2026-03-23 14:09:52 +01:00
parent 8ed171dbe6
commit 5d5536a1a6
2 changed files with 248 additions and 42 deletions
+64 -23
View File
@@ -2413,6 +2413,47 @@ turnLoop:
if toolResult == nil {
toolResult = tools.ErrorResult("hook returned nil tool result")
}
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
part.ContentType = meta.ContentType
part.Type = inferMediaType(meta.Filename, meta.ContentType)
}
}
parts = append(parts, part)
}
outboundMedia := bus.OutboundMediaMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Parts: parts,
}
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
logger.WarnCF("agent", "Failed to deliver handled tool media",
map[string]any{
"agent_id": ts.agent.ID,
"tool": toolName,
"channel": ts.channel,
"chat_id": ts.chatID,
"error": err.Error(),
})
toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err)
}
} else if al.bus != nil {
al.bus.PublishOutboundMedia(ctx, outboundMedia)
// Queuing media is only best-effort; it has not been delivered yet.
toolResult.ResponseHandled = false
}
}
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
}
if !toolResult.ResponseHandled {
allResponsesHandled = false
}
@@ -2430,29 +2471,6 @@ turnLoop:
})
}
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
part.ContentType = meta.ContentType
part.Type = inferMediaType(meta.Filename, meta.ContentType)
}
}
parts = append(parts, part)
}
al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Parts: parts,
})
}
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
}
contentForLLM := toolResult.ContentForLLM()
toolResultMsg := providers.Message{
@@ -2543,6 +2561,29 @@ turnLoop:
}
if allResponsesHandled {
if len(pendingMessages) > 0 {
logger.InfoCF("agent", "Pending steering exists after handled tool delivery; continuing turn before finalizing",
map[string]any{
"agent_id": ts.agent.ID,
"steering_count": len(pendingMessages),
"session_key": ts.sessionKey,
})
finalContent = ""
goto turnLoop
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
logger.InfoCF("agent", "Steering arrived after handled tool delivery; continuing turn before finalizing",
map[string]any{
"agent_id": ts.agent.ID,
"steering_count": len(steerMsgs),
"session_key": ts.sessionKey,
})
pendingMessages = append(pendingMessages, steerMsgs...)
finalContent = ""
goto turnLoop
}
summaryMsg := providers.Message{
Role: "assistant",
Content: handledToolResponseSummary,
+184 -19
View File
@@ -33,6 +33,41 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
// fakeMediaChannel is a test double that embeds fakeChannel and additionally
// records every media message handed to SendMedia.
type fakeMediaChannel struct {
	fakeChannel

	// sentMedia holds the messages passed to SendMedia, in call order.
	sentMedia []bus.OutboundMediaMessage
}

// SendMedia records msg and always reports success. The context is unused.
func (f *fakeMediaChannel) SendMedia(_ context.Context, msg bus.OutboundMediaMessage) error {
	recorded := append(f.sentMedia, msg)
	f.sentMedia = recorded
	return nil
}
// newStartedTestChannelManager creates a channels.Manager from an empty
// config, registers ch under name, and starts all channels. Any failure
// aborts the test. A cleanup is registered that stops all channels when the
// test finishes and fails the test if stopping returns an error.
func newStartedTestChannelManager(
	t *testing.T,
	msgBus *bus.MessageBus,
	store media.MediaStore,
	name string,
	ch channels.Channel,
) *channels.Manager {
	t.Helper()

	manager, newErr := channels.NewManager(&config.Config{}, msgBus, store)
	if newErr != nil {
		t.Fatalf("NewManager() error = %v", newErr)
	}
	manager.RegisterChannel(name, ch)

	startErr := manager.StartAll(context.Background())
	if startErr != nil {
		t.Fatalf("StartAll() error = %v", startErr)
	}

	// Stop the manager once the calling test (and its subtests) complete.
	stopManager := func() {
		stopErr := manager.StopAll(context.Background())
		if stopErr != nil {
			t.Fatalf("StopAll() error = %v", stopErr)
		}
	}
	t.Cleanup(stopManager)

	return manager
}
type recordingProvider struct {
lastMessages []providers.Message
}
@@ -554,6 +589,8 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
store := media.NewFileMediaStore()
al.SetMediaStore(store)
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
imagePath := filepath.Join(tmpDir, "screen.png")
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
@@ -587,16 +624,20 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
t.Fatal("expected tools to be available on the first LLM call")
}
if len(telegramChannel.sentMedia) != 1 {
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
}
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
}
if len(telegramChannel.sentMedia[0].Parts) != 1 {
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
}
select {
case mediaMsg := <-msgBus.OutboundMediaChan():
if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" {
t.Fatalf("unexpected outbound media target: %+v", mediaMsg)
}
if len(mediaMsg.Parts) != 1 {
t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts))
}
case extra := <-msgBus.OutboundMediaChan():
t.Fatalf("expected handled media to bypass async queue, got %+v", extra)
default:
t.Fatal("expected outbound media message to be published")
}
defaultAgent := al.GetRegistry().GetDefaultAgent()
@@ -623,6 +664,59 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
}
}
// TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning drives
// processMessage with a tool that returns handled media and calls Steer while
// it executes. It expects the turn to continue into a second LLM call that
// answers the steering message, and the media to be sent exactly once through
// the registered channel.
func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) {
	const wantResponse = "Handled the queued steering message."

	workspace := t.TempDir()
	defaults := config.AgentDefaults{
		Workspace:         workspace,
		ModelName:         "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
	}

	msgBus := bus.NewMessageBus()
	provider := &handledMediaWithSteeringProvider{}
	al := NewAgentLoop(cfg, msgBus, provider)

	store := media.NewFileMediaStore()
	al.SetMediaStore(store)

	// Route the "telegram" channel to a fake that records sent media.
	telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
	manager := newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)
	al.SetChannelManager(manager)

	imagePath := filepath.Join(workspace, "screen-steering.png")
	writeErr := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644)
	if writeErr != nil {
		t.Fatalf("WriteFile(imagePath) error = %v", writeErr)
	}

	steeringTool := &handledMediaWithSteeringTool{
		store: store,
		path:  imagePath,
		loop:  al,
	}
	al.RegisterTool(steeringTool)

	inbound := bus.InboundMessage{
		Channel:  "telegram",
		ChatID:   "chat1",
		SenderID: "user1",
		Content:  "take a screenshot of the screen and send it to me",
	}
	response, err := al.processMessage(context.Background(), inbound)
	if err != nil {
		t.Fatalf("processMessage() error = %v", err)
	}

	if response != wantResponse {
		t.Fatalf("response = %q, want queued steering response", response)
	}
	if got := provider.calls; got != 2 {
		t.Fatalf("expected 2 LLM calls after queued steering, got %d", got)
	}
	if sent := len(telegramChannel.sentMedia); sent != 1 {
		t.Fatalf("expected exactly 1 synchronously sent media message, got %d", sent)
	}
}
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
tmpDir := t.TempDir()
cfg := config.DefaultConfig()
@@ -637,6 +731,8 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
store := media.NewFileMediaStore()
al.SetMediaStore(store)
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
@@ -668,21 +764,19 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls)
}
select {
case mediaMsg := <-msgBus.OutboundMediaChan():
if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" {
t.Fatalf("unexpected outbound media target: %+v", mediaMsg)
}
if len(mediaMsg.Parts) != 1 {
t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts))
}
default:
t.Fatal("expected outbound media from send_file")
if len(telegramChannel.sentMedia) != 1 {
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
}
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
}
if len(telegramChannel.sentMedia[0].Parts) != 1 {
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
}
select {
case extra := <-msgBus.OutboundMediaChan():
t.Fatalf("expected exactly one outbound media delivery, got extra %+v", extra)
t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra)
default:
}
}
@@ -975,6 +1069,77 @@ func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *to
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
}
// handledMediaWithSteeringProvider is a scripted test provider. Its first
// Chat call requests the handled_media_with_steering_tool tool. Every later
// call succeeds only if the conversation contains the steering user message.
type handledMediaWithSteeringProvider struct {
	// calls counts how many times Chat has been invoked.
	calls int
}

// Chat returns the scripted response for the current call number. Only the
// messages argument is inspected; the other arguments are ignored.
func (m *handledMediaWithSteeringProvider) Chat(
	_ context.Context,
	messages []providers.Message,
	_ []providers.ToolDefinition,
	_ string,
	_ map[string]any,
) (*providers.LLMResponse, error) {
	m.calls++

	if m.calls != 1 {
		// Follow-up calls must see the steering message among the user messages.
		for i := range messages {
			if messages[i].Role != "user" {
				continue
			}
			if messages[i].Content == "what about this instead?" {
				return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil
			}
		}
		return nil, fmt.Errorf("provider did not receive queued steering message")
	}

	// First call: ask the agent loop to run the steering tool.
	call := providers.ToolCall{
		ID:        "call_handled_media_steering",
		Type:      "function",
		Name:      "handled_media_with_steering_tool",
		Arguments: map[string]any{},
	}
	return &providers.LLMResponse{
		Content:   "Taking the screenshot now.",
		ToolCalls: []providers.ToolCall{call},
	}, nil
}

// GetDefaultModel returns the fixed model name reported by this test provider.
func (m *handledMediaWithSteeringProvider) GetDefaultModel() string {
	const model = "handled-media-with-steering-model"
	return model
}
// handledMediaWithSteeringTool is a test tool that calls Steer on the agent
// loop while executing and then returns a stored image as handled media.
type handledMediaWithSteeringTool struct {
	store media.MediaStore // store that receives the file at path
	path  string           // file to register with the store
	loop  *AgentLoop       // loop that receives the steering message
}

// Name returns the identifier under which the tool is registered.
func (m *handledMediaWithSteeringTool) Name() string {
	return "handled_media_with_steering_tool"
}

// Description returns the human-readable summary of the tool.
func (m *handledMediaWithSteeringTool) Description() string {
	const desc = "Returns handled media and enqueues a steering message during execution"
	return desc
}

// Parameters returns an object schema with no properties.
func (m *handledMediaWithSteeringTool) Parameters() map[string]any {
	schema := make(map[string]any, 2)
	schema["type"] = "object"
	schema["properties"] = map[string]any{}
	return schema
}

// Execute first passes a user steering message to the loop's Steer method,
// then stores the file at m.path and returns it as a media result marked as
// response-handled. A failure in either step is returned as an error result.
// Both arguments are ignored.
func (m *handledMediaWithSteeringTool) Execute(_ context.Context, _ map[string]any) *tools.ToolResult {
	steering := providers.Message{Role: "user", Content: "what about this instead?"}
	if steerErr := m.loop.Steer(steering); steerErr != nil {
		return tools.ErrorResult(steerErr.Error()).WithError(steerErr)
	}

	meta := media.MediaMeta{
		Filename:    filepath.Base(m.path),
		ContentType: "image/png",
		Source:      "test:handled_media_with_steering_tool",
	}
	ref, storeErr := m.store.Store(m.path, meta, "test:handled_media_with_steering")
	if storeErr != nil {
		return tools.ErrorResult(storeErr.Error()).WithError(storeErr)
	}

	result := tools.MediaResult("Attachment delivered by tool.", []string{ref})
	return result.WithResponseHandled()
}
type mediaArtifactTool struct {
store media.MediaStore
path string