mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix delivery and steering
This commit is contained in:
+64
-23
@@ -2413,6 +2413,47 @@ turnLoop:
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
}
|
||||
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
|
||||
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
|
||||
logger.WarnCF("agent", "Failed to deliver handled tool media",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err)
|
||||
}
|
||||
} else if al.bus != nil {
|
||||
al.bus.PublishOutboundMedia(ctx, outboundMedia)
|
||||
// Queuing media is only best-effort; it has not been delivered yet.
|
||||
toolResult.ResponseHandled = false
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
|
||||
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
|
||||
}
|
||||
|
||||
if !toolResult.ResponseHandled {
|
||||
allResponsesHandled = false
|
||||
}
|
||||
@@ -2430,29 +2471,6 @@ turnLoop:
|
||||
})
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
|
||||
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
|
||||
}
|
||||
contentForLLM := toolResult.ContentForLLM()
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
@@ -2543,6 +2561,29 @@ turnLoop:
|
||||
}
|
||||
|
||||
if allResponsesHandled {
|
||||
if len(pendingMessages) > 0 {
|
||||
logger.InfoCF("agent", "Pending steering exists after handled tool delivery; continuing turn before finalizing",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"steering_count": len(pendingMessages),
|
||||
"session_key": ts.sessionKey,
|
||||
})
|
||||
finalContent = ""
|
||||
goto turnLoop
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
logger.InfoCF("agent", "Steering arrived after handled tool delivery; continuing turn before finalizing",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"steering_count": len(steerMsgs),
|
||||
"session_key": ts.sessionKey,
|
||||
})
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
finalContent = ""
|
||||
goto turnLoop
|
||||
}
|
||||
|
||||
summaryMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: handledToolResponseSummary,
|
||||
|
||||
+184
-19
@@ -33,6 +33,41 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
type fakeMediaChannel struct {
|
||||
fakeChannel
|
||||
sentMedia []bus.OutboundMediaMessage
|
||||
}
|
||||
|
||||
func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
f.sentMedia = append(f.sentMedia, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newStartedTestChannelManager(
|
||||
t *testing.T,
|
||||
msgBus *bus.MessageBus,
|
||||
store media.MediaStore,
|
||||
name string,
|
||||
ch channels.Channel,
|
||||
) *channels.Manager {
|
||||
t.Helper()
|
||||
|
||||
cm, err := channels.NewManager(&config.Config{}, msgBus, store)
|
||||
if err != nil {
|
||||
t.Fatalf("NewManager() error = %v", err)
|
||||
}
|
||||
cm.RegisterChannel(name, ch)
|
||||
if err := cm.StartAll(context.Background()); err != nil {
|
||||
t.Fatalf("StartAll() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := cm.StopAll(context.Background()); err != nil {
|
||||
t.Fatalf("StopAll() error = %v", err)
|
||||
}
|
||||
})
|
||||
return cm
|
||||
}
|
||||
|
||||
type recordingProvider struct {
|
||||
lastMessages []providers.Message
|
||||
}
|
||||
@@ -554,6 +589,8 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
imagePath := filepath.Join(tmpDir, "screen.png")
|
||||
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
||||
@@ -587,16 +624,20 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
t.Fatal("expected tools to be available on the first LLM call")
|
||||
}
|
||||
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
||||
}
|
||||
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
||||
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
||||
}
|
||||
|
||||
select {
|
||||
case mediaMsg := <-msgBus.OutboundMediaChan():
|
||||
if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" {
|
||||
t.Fatalf("unexpected outbound media target: %+v", mediaMsg)
|
||||
}
|
||||
if len(mediaMsg.Parts) != 1 {
|
||||
t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts))
|
||||
}
|
||||
case extra := <-msgBus.OutboundMediaChan():
|
||||
t.Fatalf("expected handled media to bypass async queue, got %+v", extra)
|
||||
default:
|
||||
t.Fatal("expected outbound media message to be published")
|
||||
}
|
||||
|
||||
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
||||
@@ -623,6 +664,59 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &handledMediaWithSteeringProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
imagePath := filepath.Join(tmpDir, "screen-steering.png")
|
||||
if err := os.WriteFile(imagePath, []byte("fake screenshot"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(imagePath) error = %v", err)
|
||||
}
|
||||
|
||||
al.RegisterTool(&handledMediaWithSteeringTool{
|
||||
store: store,
|
||||
path: imagePath,
|
||||
loop: al,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Handled the queued steering message." {
|
||||
t.Fatalf("response = %q, want queued steering response", response)
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("expected 2 LLM calls after queued steering, got %d", provider.calls)
|
||||
}
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := config.DefaultConfig()
|
||||
@@ -637,6 +731,8 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
@@ -668,21 +764,19 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
t.Fatalf("expected 2 LLM calls (artifact + send_file), got %d", provider.calls)
|
||||
}
|
||||
|
||||
select {
|
||||
case mediaMsg := <-msgBus.OutboundMediaChan():
|
||||
if mediaMsg.Channel != "telegram" || mediaMsg.ChatID != "chat1" {
|
||||
t.Fatalf("unexpected outbound media target: %+v", mediaMsg)
|
||||
}
|
||||
if len(mediaMsg.Parts) != 1 {
|
||||
t.Fatalf("expected exactly 1 outbound media part, got %d", len(mediaMsg.Parts))
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected outbound media from send_file")
|
||||
if len(telegramChannel.sentMedia) != 1 {
|
||||
t.Fatalf("expected exactly 1 synchronously sent media message, got %d", len(telegramChannel.sentMedia))
|
||||
}
|
||||
if telegramChannel.sentMedia[0].Channel != "telegram" || telegramChannel.sentMedia[0].ChatID != "chat1" {
|
||||
t.Fatalf("unexpected sent media target: %+v", telegramChannel.sentMedia[0])
|
||||
}
|
||||
if len(telegramChannel.sentMedia[0].Parts) != 1 {
|
||||
t.Fatalf("expected exactly 1 sent media part, got %d", len(telegramChannel.sentMedia[0].Parts))
|
||||
}
|
||||
|
||||
select {
|
||||
case extra := <-msgBus.OutboundMediaChan():
|
||||
t.Fatalf("expected exactly one outbound media delivery, got extra %+v", extra)
|
||||
t.Fatalf("expected synchronous send_file delivery to bypass async queue, got %+v", extra)
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -975,6 +1069,77 @@ func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *to
|
||||
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
// handledMediaWithSteeringProvider is a scripted test provider whose Chat
// method behaves differently on the first call than on later calls.
type handledMediaWithSteeringProvider struct {
	// calls counts how many times Chat has been invoked.
	calls int
}
|
||||
|
||||
func (m *handledMediaWithSteeringProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Taking the screenshot now.",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_handled_media_steering",
|
||||
Type: "function",
|
||||
Name: "handled_media_with_steering_tool",
|
||||
Arguments: map[string]any{},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "user" && msg.Content == "what about this instead?" {
|
||||
return &providers.LLMResponse{Content: "Handled the queued steering message."}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("provider did not receive queued steering message")
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringProvider) GetDefaultModel() string {
|
||||
return "handled-media-with-steering-model"
|
||||
}
|
||||
|
||||
type handledMediaWithSteeringTool struct {
|
||||
store media.MediaStore
|
||||
path string
|
||||
loop *AgentLoop
|
||||
}
|
||||
|
||||
// Name returns the identifier under which this tool is invoked.
func (m *handledMediaWithSteeringTool) Name() string {
	const toolName = "handled_media_with_steering_tool"
	return toolName
}
|
||||
func (m *handledMediaWithSteeringTool) Description() string {
|
||||
return "Returns handled media and enqueues a steering message during execution"
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *handledMediaWithSteeringTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
if err := m.loop.Steer(providers.Message{Role: "user", Content: "what about this instead?"}); err != nil {
|
||||
return tools.ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
|
||||
ref, err := m.store.Store(m.path, media.MediaMeta{
|
||||
Filename: filepath.Base(m.path),
|
||||
ContentType: "image/png",
|
||||
Source: "test:handled_media_with_steering_tool",
|
||||
}, "test:handled_media_with_steering")
|
||||
if err != nil {
|
||||
return tools.ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
type mediaArtifactTool struct {
|
||||
store media.MediaStore
|
||||
path string
|
||||
|
||||
Reference in New Issue
Block a user