diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index ef4802a78..d582676e7 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -400,6 +400,7 @@ Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous |------------|------|---------|-------------| | `tools.allow_read_paths` | string[] | `[]` | Additional paths allowed for reading outside workspace | | `tools.allow_write_paths` | string[] | `[]` | Additional paths allowed for writing outside workspace | +| `tools.message.media_enabled` | bool | `false` | Allows the `message` tool to attach local media files by path. This is separate from `tools.send_file.enabled`; enable it only when unified text/media/caption delivery is intended. | ### Read File Mode diff --git a/pkg/agent/agent_init.go b/pkg/agent/agent_init.go index 50f0227a1..17629892d 100644 --- a/pkg/agent/agent_init.go +++ b/pkg/agent/agent_init.go @@ -161,26 +161,58 @@ func registerSharedTools( // Message tool if cfg.Tools.IsToolEnabled("message") { messageTool := tools.NewMessageTool() + if cfg.Tools.Message.MediaEnabled { + messageTool.ConfigureLocalMedia( + agent.Workspace, + cfg.Agents.Defaults.RestrictToWorkspace, + cfg.Agents.Defaults.GetMaxMediaSize(), + allowReadPaths, + ) + } messageTool.SetSendCallback(func( ctx context.Context, channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, ) error { - pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer pubCancel() outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID) outboundAgentID, outboundSessionKey, outboundScope := outboundTurnMetadata( tools.ToolAgentID(ctx), tools.ToolSessionKey(ctx), tools.ToolSessionScope(ctx), ) - return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ + if len(mediaParts) > 0 { + outboundMedia := bus.OutboundMediaMessage{ + Channel: channel, + ChatID: chatID, + Context: outboundCtx, + AgentID: outboundAgentID, + SessionKey: outboundSessionKey, + Scope: outboundScope, + Parts: mediaParts, + } + if al.channelManager != nil && channel != "" { + return al.channelManager.SendMedia(ctx, outboundMedia) + } + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutboundMedia(pubCtx, outboundMedia) + } + outboundMessage := bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, Context: outboundCtx, AgentID: outboundAgentID, SessionKey: outboundSessionKey, Scope: outboundScope, Content: content, ReplyToMessageID: replyToMessageID, - }) + } + if al.channelManager != nil && channel != "" { + return al.channelManager.SendMessage(ctx, outboundMessage) + } + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutbound(pubCtx, outboundMessage) }) agent.Tools.Register(messageTool) } diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 08cf71d79..c3562bdaa 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -377,7 +377,11 @@ func TestPublishResponseIfNeeded_DismissesToolFeedbackWhenMessageToolAlreadySent t.Fatal("expected default agent") } mt := tools.NewMessageTool() - mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + mt.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { return nil }) defaultAgent.Tools.Register(mt) diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index d09c021c7..2fef72273 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -52,6 +52,8 @@ type FeishuChannel struct { progress *channels.ToolFeedbackAnimator deleteMessageFn func(context.Context, string, string) error + sendMediaPartFn func(context.Context, string, bus.MediaPart, media.MediaStore) error + sendTextFn func(context.Context, string, string) (string, error) } type cachedMessage struct { @@ -78,6 +80,8 @@ func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.M client: lark.NewClient(cfg.AppID, cfg.AppSecret.String(), opts...), } ch.deleteMessageFn = ch.deleteMessageAPI + ch.sendMediaPartFn = ch.sendMediaPart + ch.sendTextFn = ch.sendText ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage) ch.SetOwner(ch) return ch, nil @@ -497,8 +501,16 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed) } + caption := firstMediaCaption(msg.Parts) + sentAny := false for _, part := range msg.Parts { - if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil { + if err := c.sendMediaPartFn(ctx, msg.ChatID, part, store); err != nil { + return nil, err + } + sentAny = true + } + if sentAny && caption != "" { + if _, err := c.sendTextFn(ctx, msg.ChatID, caption); err != nil { return nil, err } } @@ -557,6 +569,15 @@ func (c *FeishuChannel) sendMediaPart( return nil } +func firstMediaCaption(parts []bus.MediaPart) string { + for _, part := range parts { + if caption := strings.TrimSpace(part.Caption); caption != "" { + return caption + } + } + return "" +} + // --- Inbound message handling --- func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { diff --git a/pkg/channels/feishu/feishu_64_test.go b/pkg/channels/feishu/feishu_64_test.go index d256325ad..1dbacab89 100644 --- a/pkg/channels/feishu/feishu_64_test.go +++ b/pkg/channels/feishu/feishu_64_test.go @@ -9,7 +9,9 @@ import ( larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/media" ) func TestExtractContent(t *testing.T) { @@ -319,6 +321,43 @@ func TestFinalizeTrackedToolFeedbackMessage_ClearAfterSuccessfulEdit(t *testing. } } +func TestSendMedia_SendsCaptionFallbackAfterMedia(t *testing.T) { + ch := &FeishuChannel{ + BaseChannel: channels.NewBaseChannel("feishu", nil, nil, nil), + progress: channels.NewToolFeedbackAnimator(nil), + } + ch.SetRunning(true) + ch.SetMediaStore(media.NewFileMediaStore()) + + var mediaOrder []string + var textCalls []string + ch.sendMediaPartFn = func(ctx context.Context, chatID string, part bus.MediaPart, store media.MediaStore) error { + mediaOrder = append(mediaOrder, part.Type) + return nil + } + ch.sendTextFn = func(ctx context.Context, chatID, text string) (string, error) { + textCalls = append(textCalls, chatID+"|"+text) + return "msg-1", nil + } + + _, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "oc_123", + Parts: []bus.MediaPart{ + {Type: "image", Caption: "shared caption"}, + {Type: "file"}, + }, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if len(mediaOrder) != 2 { + t.Fatalf("media sends = %v, want 2 sends", mediaOrder) + } + if len(textCalls) != 1 || textCalls[0] != "oc_123|shared caption" { + t.Fatalf("textCalls = %v, want [oc_123|shared caption]", textCalls) + } +} + func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) { ch := &FeishuChannel{ progress: channels.NewToolFeedbackAnimator(nil), diff --git a/pkg/channels/pico/pico_test.go b/pkg/channels/pico/pico_test.go index a793d7ad7..9cdf79044 100644 --- a/pkg/channels/pico/pico_test.go +++ b/pkg/channels/pico/pico_test.go @@ -835,6 +835,75 @@ func TestSendMedia_DismissesTrackedToolFeedbackMessage(t *testing.T) { } } +func TestSendMedia_IncludesCaptionAndAttachmentsInSinglePayload(t *testing.T) { + ch := newTestPicoChannel(t) + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + if err := ch.Start(context.Background()); err != nil { + t.Fatalf("Start() error = %v", err) + } + defer ch.Stop(context.Background()) + + clientConn, received, cleanup := newTestPicoWebSocket(t) + defer cleanup() + ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"}) + + localPath := filepath.Join(t.TempDir(), "photo.png") + if err := os.WriteFile(localPath, []byte("png-body"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: "photo.png", + ContentType: "image/png", + }, "test-scope") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + _, err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "pico:sess-1", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "image", + Filename: "photo.png", + ContentType: "image/png", + Caption: "recipe translation", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + select { + case msg := <-received: + if msg.Type != TypeMessageCreate { + t.Fatalf("message type = %q, want %q", msg.Type, TypeMessageCreate) + } + payload := msg.Payload + if got := payload[PayloadKeyContent]; got != "recipe translation" { + t.Fatalf("content = %#v, want %q", got, "recipe translation") + } + rawAttachments, ok := payload["attachments"].([]any) + if !ok || len(rawAttachments) != 1 { + t.Fatalf("attachments = %#v, want 1 attachment", payload["attachments"]) + } + attachment, ok := rawAttachments[0].(map[string]any) + if !ok { + t.Fatalf("attachment = %#v, want map", rawAttachments[0]) + } + if got := attachment["type"]; got != "image" { + t.Fatalf("attachment type = %#v, want image", got) + } + if got := attachment["filename"]; got != "photo.png" { + t.Fatalf("attachment filename = %#v, want photo.png", got) + } + case <-time.After(time.Second): + t.Fatal("expected media payload to be delivered") + } +} + func TestPicoDownloadURLForRef(t *testing.T) { got, err := picoDownloadURLForRef("media://attachment-1") if err != nil { diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index fa62a4605..b021feda9 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -29,6 +29,8 @@ type SlackChannel struct { ctx context.Context cancel context.CancelFunc pendingAcks sync.Map + uploadFileFn func(context.Context, slack.UploadFileParameters) error + postTextFn func(context.Context, string, string, string) error } type slackMessageRef struct { @@ -63,6 +65,18 @@ func NewSlackChannel( config: cfg, api: api, socketClient: socketClient, + uploadFileFn: func(ctx context.Context, params slack.UploadFileParameters) error { + _, err := api.UploadFileContext(ctx, params) + return err + }, + postTextFn: func(ctx context.Context, channelID, threadTS, text string) error { + opts := []slack.MsgOption{slack.MsgOptionText(text, false)} + if threadTS != "" { + opts = append(opts, slack.MsgOptionTS(threadTS)) + } + _, _, err := api.PostMessageContext(ctx, channelID, opts...) + return err + }, }, nil } @@ -171,6 +185,8 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed) } + caption := slackFirstMediaCaption(msg.Parts) + sentAny := false for _, part := range msg.Parts { localPath, err := store.Resolve(part.Ref) if err != nil { @@ -191,7 +207,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa title = filename } - _, err = c.api.UploadFileContext(ctx, slack.UploadFileParameters{ + err = c.uploadFileFn(ctx, slack.UploadFileParameters{ Channel: channelID, ThreadTimestamp: threadTS, File: localPath, @@ -205,6 +221,13 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa }) return nil, fmt.Errorf("slack send media: %w", channels.ErrTemporary) } + sentAny = true + } + + if sentAny && caption != "" { + if err := c.postTextFn(ctx, channelID, threadTS, caption); err != nil { + return nil, fmt.Errorf("slack send media caption fallback: %w", channels.ErrTemporary) + } } // UploadFile does not expose the posted message timestamp in its @@ -212,6 +235,15 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa return nil, nil } +func slackFirstMediaCaption(parts []bus.MediaPart) string { + for _, part := range parts { + if caption := strings.TrimSpace(part.Caption); caption != "" { + return caption + } + } + return "" +} + // ReactToMessage implements channels.ReactionCapable. // It adds an "eyes" (👀) reaction to the inbound message and returns an undo function // that removes the reaction. diff --git a/pkg/channels/slack/slack_test.go b/pkg/channels/slack/slack_test.go index a72521d67..b85f3f028 100644 --- a/pkg/channels/slack/slack_test.go +++ b/pkg/channels/slack/slack_test.go @@ -1,10 +1,17 @@ package slack import ( + "context" + "os" + "path/filepath" "testing" + slacksdk "github.com/slack-go/slack" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" ) func TestParseSlackChatID(t *testing.T) { @@ -184,3 +191,74 @@ func TestSlackChannelIsAllowed(t *testing.T) { } }) } + +func TestSendMedia_SendsCaptionFallbackAfterUploads(t *testing.T) { + ch := &SlackChannel{ + BaseChannel: channels.NewBaseChannel("slack", nil, nil, nil), + } + ch.SetRunning(true) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + localPath := filepath.Join(tmpDir, "report.txt") + if err := os.WriteFile(localPath, []byte("attachment body"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: "report.txt", + ContentType: "text/plain", + }, "test-scope") + if err != nil { + t.Fatalf("Store() error = %v", err) + } + + var uploaded []slackUploadRecord + var posted []string + ch.uploadFileFn = func(ctx context.Context, params slacksdk.UploadFileParameters) error { + uploaded = append(uploaded, slackUploadRecord{ + Channel: params.Channel, + Thread: params.ThreadTimestamp, + File: params.File, + Name: params.Filename, + Title: params.Title, + }) + return nil + } + ch.postTextFn = func(ctx context.Context, channelID, threadTS, text string) error { + posted = append(posted, channelID+"|"+threadTS+"|"+text) + return nil + } + + _, err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "C123456/1234567890.123456", + Parts: []bus.MediaPart{{ + Ref: ref, + Type: "file", + Filename: "report.txt", + ContentType: "text/plain", + Caption: "shared caption", + }}, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if len(uploaded) != 1 { + t.Fatalf("uploads = %v, want 1 upload", uploaded) + } + if uploaded[0].Title != "shared caption" { + t.Fatalf("upload title = %q, want shared caption", uploaded[0].Title) + } + if len(posted) != 1 || posted[0] != "C123456|1234567890.123456|shared caption" { + t.Fatalf("posted = %v, want fallback text in same thread", posted) + } +} + +type slackUploadRecord struct { + Channel string + Thread string + File string + Name string + Title string +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 45672e5ee..8fe325b25 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -44,7 +44,10 @@ var ( reInlineCode = regexp.MustCompile("`([^`]+)`") ) -const defaultMediaGroupDelay = 500 * time.Millisecond +const ( + defaultMediaGroupDelay = 500 * time.Millisecond + telegramCaptionLimit = 1024 +) type TelegramChannel struct { *channels.BaseChannel @@ -639,6 +642,34 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe } var messageIDs []string + leadingCaption := telegramLeadingCaption(msg.Parts) + if len([]rune(leadingCaption)) > telegramCaptionLimit { + leadingIDs, leadingErr := c.sendCaptionText(ctx, chatID, threadID, leadingCaption) + if leadingErr != nil { + return nil, leadingErr + } + messageIDs = append(messageIDs, leadingIDs...) + msg = telegramClearMediaCaptions(msg) + } + + if len(msg.Parts) > 1 && telegramCanSendMediaGroup(msg.Parts) { + groupIDs, err := c.sendImageMediaGroups(ctx, chatID, threadID, store, msg.Parts) + if err != nil { + logger.ErrorCF("telegram", "Failed to send media group", map[string]any{ + "count": len(msg.Parts), + "error": err.Error(), + }) + return nil, fmt.Errorf("telegram send media group: %w", channels.ErrTemporary) + } + if len(groupIDs) > 0 { + messageIDs = append(messageIDs, groupIDs...) + if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, trackedChatID, trackedMsgID) + } + return messageIDs, nil + } + } + for _, part := range msg.Parts { localPath, err := store.Resolve(part.Ref) if err != nil { @@ -742,6 +773,154 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe return messageIDs, nil } +func telegramCanSendMediaGroup(parts []bus.MediaPart) bool { + if len(parts) < 2 { + return false + } + for _, part := range parts { + if part.Type != "image" { + return false + } + } + return true +} + +func (c *TelegramChannel) sendImageMediaGroups( + ctx context.Context, + chatID int64, + threadID int, + store media.MediaStore, + parts []bus.MediaPart, +) ([]string, error) { + const maxGroupSize = 10 + + messageIDs := make([]string, 0, len(parts)) + for start := 0; start < len(parts); start += maxGroupSize { + end := start + maxGroupSize + if end > len(parts) { + end = len(parts) + } + groupIDs, err := c.sendSingleImageMediaGroup(ctx, chatID, threadID, store, parts[start:end]) + if err != nil { + return nil, err + } + messageIDs = append(messageIDs, groupIDs...) + } + return messageIDs, nil +} + +func (c *TelegramChannel) sendSingleImageMediaGroup( + ctx context.Context, + chatID int64, + threadID int, + store media.MediaStore, + parts []bus.MediaPart, +) ([]string, error) { + opened := make([]*os.File, 0, len(parts)) + defer func() { + for _, file := range opened { + file.Close() + } + }() + + inputMedia := make([]telego.InputMedia, 0, len(parts)) + for i, part := range parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("telegram", "Failed to resolve media ref for media group", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + return nil, err + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("telegram", "Failed to open media file for media group", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + return nil, err + } + opened = append(opened, file) + + mediaItem := &telego.InputMediaPhoto{ + Type: telego.MediaTypePhoto, + Media: telego.InputFile{File: file}, + } + if i == 0 { + mediaItem.Caption = part.Caption + } + inputMedia = append(inputMedia, mediaItem) + } + + results, err := c.bot.SendMediaGroup(ctx, &telego.SendMediaGroupParams{ + ChatID: tu.ID(chatID), + MessageThreadID: threadID, + Media: inputMedia, + }) + if err != nil { + return nil, err + } + + messageIDs := make([]string, 0, len(results)) + for _, result := range results { + messageIDs = append(messageIDs, strconv.Itoa(result.MessageID)) + } + return messageIDs, nil +} + +func (c *TelegramChannel) sendCaptionText( + ctx context.Context, + chatID int64, + threadID int, + text string, +) ([]string, error) { + text = strings.TrimSpace(text) + if text == "" { + return nil, nil + } + chunks := channels.SplitMessage(text, c.MaxMessageLength()) + messageIDs := make([]string, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + msgID, err := c.sendChunk(ctx, sendChunkParams{ + chatID: chatID, + threadID: threadID, + content: chunk, + mdFallback: chunk, + useMarkdownV2: false, + }) + if err != nil { + return nil, err + } + messageIDs = append(messageIDs, msgID) + } + return messageIDs, nil +} + +func telegramLeadingCaption(parts []bus.MediaPart) string { + if len(parts) == 0 { + return "" + } + return strings.TrimSpace(parts[0].Caption) +} + +func telegramClearMediaCaptions(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage { + if len(msg.Parts) == 0 { + return msg + } + cloned := msg + cloned.Parts = append([]bus.MediaPart(nil), msg.Parts...) + for i := range cloned.Parts { + cloned.Parts[i].Caption = "" + } + return cloned +} + func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { if message != nil && strings.TrimSpace(message.MediaGroupID) != "" { return c.bufferMediaGroupMessage(ctx, message) diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go index 0ebde1328..b52f2c9b2 100644 --- a/pkg/channels/telegram/telegram_test.go +++ b/pkg/channels/telegram/telegram_test.go @@ -110,6 +110,17 @@ func successResponseWithMessageID(t *testing.T, messageID int) *ta.Response { return &ta.Response{Ok: true, Result: b} } +func successMediaGroupResponse(t *testing.T, messageIDs ...int) *ta.Response { + t.Helper() + messages := make([]telego.Message, 0, len(messageIDs)) + for _, messageID := range messageIDs { + messages = append(messages, telego.Message{MessageID: messageID}) + } + b, err := json.Marshal(messages) + require.NoError(t, err) + return &ta.Response{Ok: true, Result: b} +} + func successUserResponse(t *testing.T, user *telego.User) *ta.Response { t.Helper() b, err := json.Marshal(user) @@ -237,6 +248,276 @@ func TestSendMedia_ImageNonDimensionErrorDoesNotFallback(t *testing.T) { assert.NotContains(t, caller.calls[0].URL, "sendDocument") } +func TestSendMedia_MultipleImagesUseMediaGroup(t *testing.T) { + constructor := &multipartRecordingConstructor{} + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + if strings.Contains(url, "sendMediaGroup") { + return successMediaGroupResponse(t, 101, 102), nil + } + t.Fatalf("unexpected API call: %s", url) + return nil, nil + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + firstPath := filepath.Join(tmpDir, "first.png") + secondPath := filepath.Join(tmpDir, "second.png") + require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644)) + require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644)) + + firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + secondRef, err := store.Store( + secondPath, + media.MediaMeta{Filename: "second.png", ContentType: "image/png"}, + "scope-1", + ) + require.NoError(t, err) + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: []bus.MediaPart{ + {Type: "image", Ref: firstRef, Caption: "album caption"}, + {Type: "image", Ref: secondRef}, + }, + }) + + require.NoError(t, err) + assert.Equal(t, []string{"101", "102"}, ids) + require.Len(t, caller.calls, 1) + assert.Contains(t, caller.calls[0].URL, "sendMediaGroup") + require.Len(t, constructor.calls, 1) + require.Len(t, constructor.calls[0].FileSizes, 2) + + var mediaPayload []map[string]any + require.NoError(t, json.Unmarshal([]byte(constructor.calls[0].Parameters["media"]), &mediaPayload)) + require.Len(t, mediaPayload, 2) + assert.Equal(t, "album caption", mediaPayload[0]["caption"]) + _, hasSecondCaption := mediaPayload[1]["caption"] + assert.False(t, hasSecondCaption) +} + +func TestSendMedia_MoreThanTenImagesSplitIntoMediaGroups(t *testing.T) { + constructor := &multipartRecordingConstructor{} + callIndex := 0 + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + if !strings.Contains(url, "sendMediaGroup") { + t.Fatalf("unexpected API call: %s", url) + } + callIndex++ + if callIndex == 1 { + return successMediaGroupResponse(t, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010), nil + } + if callIndex == 2 { + return successMediaGroupResponse(t, 1011, 1012, 1013, 1014, 1015), nil + } + t.Fatalf("unexpected sendMediaGroup call #%d", callIndex) + return nil, nil + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + parts := make([]bus.MediaPart, 0, 15) + for i := 0; i < 15; i++ { + path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png") + require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644)) + ref, err := store.Store( + path, + media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"}, + "scope-1", + ) + require.NoError(t, err) + part := bus.MediaPart{Type: "image", Ref: ref} + if i == 0 { + part.Caption = "long album caption" + } + parts = append(parts, part) + } + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: parts, + }) + + require.NoError(t, err) + assert.Equal(t, []string{ + "1001", "1002", "1003", "1004", "1005", + "1006", "1007", "1008", "1009", "1010", + "1011", "1012", "1013", "1014", "1015", + }, ids) + require.Len(t, caller.calls, 2) + require.Len(t, constructor.calls, 2) +} + +func TestSendMedia_SingleImageLongCaptionSendsTextFirst(t *testing.T) { + constructor := &multipartRecordingConstructor{} + longCaption := strings.Repeat("a", telegramCaptionLimit) + " tail overflow" + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + switch { + case strings.Contains(url, "sendMessage"): + return successResponseWithMessageID(t, 201), nil + case strings.Contains(url, "sendPhoto"): + return successResponseWithMessageID(t, 202), nil + default: + t.Fatalf("unexpected API call: %s", url) + return nil, nil + } + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "image.png") + require.NoError(t, os.WriteFile(path, []byte("img"), 0o644)) + ref, err := store.Store(path, media.MediaMeta{Filename: "image.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: []bus.MediaPart{{ + Type: "image", + Ref: ref, + Caption: longCaption, + }}, + }) + + require.NoError(t, err) + assert.Equal(t, []string{"201", "202"}, ids) + require.Len(t, caller.calls, 2) + assert.Contains(t, caller.calls[0].URL, "sendMessage") + assert.Contains(t, caller.calls[1].URL, "sendPhoto") + assert.Equal(t, "", constructor.calls[0].Parameters["caption"]) +} + +func TestSendMedia_MediaGroupLongCaptionSendsTextFirst(t *testing.T) { + constructor := &multipartRecordingConstructor{} + longCaption := strings.Repeat("b", telegramCaptionLimit) + " trailing explanation" + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + switch { + case strings.Contains(url, "sendMessage"): + return successResponseWithMessageID(t, 301), nil + case strings.Contains(url, "sendMediaGroup"): + return successMediaGroupResponse(t, 302, 303), nil + default: + t.Fatalf("unexpected API call: %s", url) + return nil, nil + } + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + firstPath := filepath.Join(tmpDir, "first.png") + secondPath := filepath.Join(tmpDir, "second.png") + require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644)) + require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644)) + + firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + secondRef, err := store.Store( + secondPath, + media.MediaMeta{Filename: "second.png", ContentType: "image/png"}, + "scope-1", + ) + require.NoError(t, err) + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: []bus.MediaPart{ + {Type: "image", Ref: firstRef, Caption: longCaption}, + {Type: "image", Ref: secondRef}, + }, + }) + + require.NoError(t, err) + assert.Equal(t, []string{"301", "302", "303"}, ids) + require.Len(t, caller.calls, 2) + assert.Contains(t, caller.calls[0].URL, "sendMessage") + assert.Contains(t, caller.calls[1].URL, "sendMediaGroup") +} + +func TestSendMedia_MultiGroupLongCaptionSendsTextBeforeGroups(t *testing.T) { + constructor := &multipartRecordingConstructor{} + longCaption := strings.Repeat("c", telegramCaptionLimit) + " overflow before second album" + callOrder := make([]string, 0, 3) + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + switch { + case strings.Contains(url, "sendMessage"): + callOrder = append(callOrder, "text") + return successResponseWithMessageID(t, 499), nil + case strings.Contains(url, "sendMediaGroup"): + callOrder = append(callOrder, "group") + if len(callOrder) == 2 { + return successMediaGroupResponse(t, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410), nil + } + if len(callOrder) == 3 { + return successMediaGroupResponse(t, 411, 412, 413, 414, 415), nil + } + t.Fatalf("unexpected sendMediaGroup order: %v", callOrder) + return nil, nil + default: + t.Fatalf("unexpected API call: %s", url) + return nil, nil + } + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + parts := make([]bus.MediaPart, 0, 15) + for i := 0; i < 15; i++ { + path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png") + require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644)) + ref, err := store.Store( + path, + media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"}, + "scope-1", + ) + require.NoError(t, err) + part := bus.MediaPart{Type: "image", Ref: ref} + if i == 0 { + part.Caption = longCaption + } + parts = append(parts, part) + } + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: parts, + }) + + require.NoError(t, err) + assert.Equal(t, []string{ + "499", + "401", "402", "403", "404", "405", + "406", "407", "408", "409", "410", + "411", "412", "413", "414", "415", + }, ids) + assert.Equal(t, []string{"text", "group", "group"}, callOrder) +} + func TestSend_EmptyContent(t *testing.T) { caller := &stubCaller{ callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { diff --git a/pkg/channels/weixin/weixin_test.go b/pkg/channels/weixin/weixin_test.go index aea2cbb0c..587c35a8e 100644 --- a/pkg/channels/weixin/weixin_test.go +++ b/pkg/channels/weixin/weixin_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/base64" + "encoding/json" "errors" "io" "net/http" @@ -319,3 +320,61 @@ func TestSelectInboundMediaItemFallsBackToRefMessage(t *testing.T) { t.Fatalf("selectInboundMediaItem().Type = %d, want %d", item.Type, MessageItemTypeImage) } } + +func TestSendUploadedMedia_SendsCaptionAsSeparateTextBeforeMedia(t *testing.T) { + var requests []SendMessageReq + ch := &WeixinChannel{ + api: &ApiClient{ + BaseURL: "https://ilinkai.weixin.qq.com/", + HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.Path != "/ilink/bot/sendmessage" { + t.Fatalf("sendmessage path = %q, want /ilink/bot/sendmessage", r.URL.Path) + } + var req SendMessageReq + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode sendmessage req: %v", err) + } + requests = append(requests, req) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"ret":0,"errcode":0}`))), + Header: make(http.Header), + }, nil + })}, + }, + typingCache: make(map[string]typingTicketCacheEntry), + } + + err := ch.sendUploadedMedia( + context.Background(), + "user-1", + "ctx-1", + "recipe translation", + UploadMediaTypeImage, + &uploadedFileInfo{ + downloadParam: "download-token", + aesKeyHex: "31323334353637383930616263646566", + fileSize: 11, + cipherSize: 16, + filename: "photo.png", + }, + ) + if err != nil { + t.Fatalf("sendUploadedMedia() error = %v", err) + } + if len(requests) != 2 { + t.Fatalf("sendUploadedMedia() sent %d requests, want 2", len(requests)) + } + if len(requests[0].Msg.ItemList) != 1 || requests[0].Msg.ItemList[0].Type != MessageItemTypeText { + t.Fatalf("first request item = %+v, want text item", requests[0].Msg.ItemList) + } + if got := requests[0].Msg.ItemList[0].TextItem.Text; got != "recipe translation" { + t.Fatalf("first request text = %q, want recipe translation", got) + } + if len(requests[1].Msg.ItemList) != 1 || requests[1].Msg.ItemList[0].Type != MessageItemTypeImage { + t.Fatalf("second request item = %+v, want image item", requests[1].Msg.ItemList) + } + if requests[1].Msg.ItemList[0].ImageItem == nil || requests[1].Msg.ItemList[0].ImageItem.Media == nil { + t.Fatalf("second request image media = %+v, want media ref", requests[1].Msg.ItemList[0].ImageItem) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index d9608d11e..b36014b9f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -814,6 +814,12 @@ type ToolConfig struct { Enabled bool `json:"enabled" yaml:"-" env:"ENABLED"` } +type MessageToolsConfig struct { + ToolConfig `yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` + + MediaEnabled bool `json:"media_enabled" yaml:"-" env:"PICOCLAW_TOOLS_MESSAGE_MEDIA_ENABLED"` +} + type BraveConfig struct { Enabled bool `json:"enabled" yaml:"-" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEYS"` @@ -1026,7 +1032,7 @@ type ToolsConfig struct { InstallSkill ToolConfig `json:"install_skill" yaml:"-" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"` ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"` LoadImage ToolConfig `json:"load_image" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LOAD_IMAGE_"` - Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` + Message MessageToolsConfig `json:"message" yaml:"-"` ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"` Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"` SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index e34f23895..213090a15 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1480,6 +1480,16 @@ func TestLoadConfig_LoadImageCanBeDisabled(t *testing.T) { } } +func TestDefaultConfig_MessageMediaDisabled(t *testing.T) { + cfg := DefaultConfig() + if !cfg.Tools.Message.Enabled { + t.Fatal("DefaultConfig().Tools.Message.Enabled should be true") + } + if cfg.Tools.Message.MediaEnabled { + t.Fatal("DefaultConfig().Tools.Message.MediaEnabled should be false") + } +} + func TestToolsConfig_GetFilterMinLength(t *testing.T) { tests := []struct { name string diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index c411aadf3..d7bd16875 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -447,8 +447,11 @@ func DefaultConfig() *Config { LoadImage: ToolConfig{ Enabled: true, }, - Message: ToolConfig{ - Enabled: true, + Message: MessageToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, + MediaEnabled: false, }, ReadFile: ReadFileToolConfig{ Enabled: true, diff --git a/pkg/providers/oauth/codex_provider.go b/pkg/providers/oauth/codex_provider.go index 0b125997b..b0d7bd758 100644 --- a/pkg/providers/oauth/codex_provider.go +++ b/pkg/providers/oauth/codex_provider.go @@ -104,8 +104,12 @@ func (p *CodexProvider) Chat( defer stream.Close() var resp *responses.Response + var streamedText strings.Builder for stream.Next() { evt := stream.Current() + if evt.Type == "response.output_text.delta" { + streamedText.WriteString(evt.Delta) + } if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" { evtResp := evt.Response if evtResp.ID != "" { @@ -153,7 +157,11 @@ func (p *CodexProvider) Chat( return nil, fmt.Errorf("codex API call: stream ended without completed response") } - return orc.ParseResponseFromStruct(resp), nil + parsed := orc.ParseResponseFromStruct(resp) + if parsed.Content == "" && streamedText.Len() > 0 { + parsed.Content = streamedText.String() + } + return parsed, nil } func (p *CodexProvider) GetDefaultModel() string { diff --git a/pkg/providers/oauth/codex_provider_test.go b/pkg/providers/oauth/codex_provider_test.go index aeeb18360..8deeb8d2a 100644 --- a/pkg/providers/oauth/codex_provider_test.go +++ b/pkg/providers/oauth/codex_provider_test.go @@ -374,6 +374,51 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { } } +func TestCodexProvider_ChatRoundTrip_OutputTextDeltaFallback(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + + var reqBody map[string]any + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + + resp := map[string]any{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": nil, + } + writeOutputTextDeltaSSE(w, "OK", resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + resp, err := provider.Chat( + t.Context(), + []Message{{Role: "user", Content: "Hello"}}, + nil, + "gpt-4o", + map[string]any{}, + ) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "OK" { + t.Errorf("Content = %q, want %q", resp.Content, "OK") + } +} + func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/responses" { @@ -647,3 +692,24 @@ func writeCompletedSSE(w http.ResponseWriter, response map[string]any) { fmt.Fprintf(w, "data: %s\n\n", string(b)) fmt.Fprintf(w, "data: [DONE]\n\n") } + +func writeOutputTextDeltaSSE(w http.ResponseWriter, delta string, response map[string]any) { + deltaEvent := map[string]any{ + "type": "response.output_text.delta", + "sequence_number": 1, + "delta": delta, + } + completedEvent := map[string]any{ + "type": "response.completed", + "sequence_number": 2, + "response": response, + } + deltaBytes, _ := json.Marshal(deltaEvent) + completedBytes, _ := json.Marshal(completedEvent) + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "event: response.output_text.delta\n") + fmt.Fprintf(w, "data: %s\n\n", string(deltaBytes)) + fmt.Fprintf(w, "event: response.completed\n") + fmt.Fprintf(w, "data: %s\n\n", string(completedBytes)) + fmt.Fprintf(w, "data: [DONE]\n\n") +} diff --git a/pkg/tools/integration/message.go b/pkg/tools/integration/message.go index 98d87bcb3..fbd8305c6 100644 --- a/pkg/tools/integration/message.go +++ b/pkg/tools/integration/message.go @@ -3,10 +3,32 @@ package integrationtools import ( "context" "fmt" + "mime" + "os" + "path/filepath" + "regexp" + "strings" "sync" + + "github.com/h2non/filetype" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" + fstools "github.com/sipeed/picoclaw/pkg/tools/fs" ) -type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error +type SendCallbackWithContext func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, +) error + +type messageMediaArg struct { + Path string + Type string + Filename string +} // sentTarget records the channel+chatID that the message tool sent to. type sentTarget struct { @@ -15,11 +37,15 @@ type sentTarget struct { } type MessageTool struct { - sendCallback SendCallbackWithContext - mu sync.Mutex - // sentTargets tracks targets sent to in the current round, keyed by session key - // to support parallel turns for different sessions. - sentTargets map[string][]sentTarget + sendCallback SendCallbackWithContext + workspace string + restrict bool + maxFileSize int + mediaStore media.MediaStore + allowPaths []*regexp.Regexp + localMediaEnabled bool + mu sync.Mutex + sentTargets map[string][]sentTarget } func NewMessageTool() *MessageTool { @@ -33,32 +59,86 @@ func (t *MessageTool) Name() string { } func (t *MessageTool) Description() string { - return "Send a message to user on a chat channel. Use this when you want to communicate something." + if !t.localMediaEnabled { + return "Send a text message to the user on a chat channel." + } + return "Send a message to the user on a chat channel. Supports text-only, media-only, or text with media attachments." } func (t *MessageTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "content": map[string]any{ - "type": "string", - "description": "The message content to send", - }, - "channel": map[string]any{ - "type": "string", - "description": "Optional: target channel (telegram, whatsapp, etc.)", - }, - "chat_id": map[string]any{ - "type": "string", - "description": "Optional: target chat/user ID", - }, - "reply_to_message_id": map[string]any{ - "type": "string", - "description": "Optional: reply target message ID for channels that support threaded replies", - }, + properties := map[string]any{ + "content": map[string]any{ + "type": "string", + "description": "Optional message text. When media is present, this text is used as the caption/body for the media message.", + }, + "channel": map[string]any{ + "type": "string", + "description": "Optional: target channel (telegram, whatsapp, etc.)", + }, + "chat_id": map[string]any{ + "type": "string", + "description": "Optional: target chat/user ID", + }, + "reply_to_message_id": map[string]any{ + "type": "string", + "description": "Optional: reply target message ID for channels that support threaded replies", }, - "required": []string{"content"}, } + params := map[string]any{ + "type": "object", + "properties": properties, + "required": []string{"content"}, + } + if t.localMediaEnabled { + properties["media"] = map[string]any{ + "type": "array", + "description": "Optional local media attachments to send with the message. Requires tools.message.media_enabled.", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the local file. Relative paths are resolved from workspace.", + }, + "type": map[string]any{ + "type": "string", + "description": "Optional media type hint: image, audio, video, or file.", + }, + "filename": map[string]any{ + "type": "string", + "description": "Optional display filename. Defaults to the basename of path.", + }, + }, + "required": []string{"path"}, + }, + } + delete(params, "required") + params["anyOf"] = []map[string]any{ + {"required": []string{"content"}}, + {"required": []string{"media"}}, + } + } + return params +} + +func (t *MessageTool) ConfigureLocalMedia( + workspace string, + restrict bool, + maxFileSize int, + allowPaths []*regexp.Regexp, +) { + t.workspace = workspace + t.restrict = restrict + if maxFileSize <= 0 { + maxFileSize = config.DefaultMaxMediaSize + } + t.maxFileSize = maxFileSize + t.allowPaths = allowPaths + t.localMediaEnabled = true +} + +func (t *MessageTool) SetMediaStore(store media.MediaStore) { + t.mediaStore = store } // ResetSentInRound resets the per-round send tracker for the given session key. @@ -98,9 +178,20 @@ func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) { } func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult { - content, ok := args["content"].(string) - if !ok { - return &ToolResult{ForLLM: "content is required", IsError: true} + content, _ := args["content"].(string) + content = strings.TrimSpace(content) + mediaArgs, err := parseMessageMediaArgs(args["media"]) + if err != nil { + return &ToolResult{ForLLM: err.Error(), IsError: true} + } + if len(mediaArgs) > 0 && !t.localMediaEnabled { + return &ToolResult{ + ForLLM: "message media attachments are disabled; enable tools.message.media_enabled to send local media through message", + IsError: true, + } + } + if content == "" && len(mediaArgs) == 0 { + return &ToolResult{ForLLM: "content or media is required", IsError: true} } channel, _ := args["channel"].(string) @@ -122,7 +213,12 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } - if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil { + parts, err := t.buildMediaParts(channel, chatID, content, mediaArgs) + if err != nil { + return &ToolResult{ForLLM: err.Error(), IsError: true, Err: err} + } + + if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID, parts); err != nil { return &ToolResult{ ForLLM: fmt.Sprintf("sending message: %v", err), IsError: true, @@ -135,9 +231,149 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID}) t.mu.Unlock() - // Silent: user already received the message directly + status := fmt.Sprintf("Message sent to %s:%s", channel, chatID) + if len(parts) > 0 { + status = fmt.Sprintf("Message with %d media attachment(s) sent to %s:%s", len(parts), channel, chatID) + } + return &ToolResult{ - ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), + ForLLM: status, Silent: true, } } + +func parseMessageMediaArgs(raw any) ([]messageMediaArg, error) { + if raw == nil { + return nil, nil + } + items, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("media must be an array") + } + result := make([]messageMediaArg, 0, len(items)) + for i, item := range items { + obj, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("media[%d] must be an object", i) + } + path, _ := obj["path"].(string) + path = strings.TrimSpace(path) + if path == "" { + return nil, fmt.Errorf("media[%d].path is required", i) + } + typ, _ := obj["type"].(string) + filename, _ := obj["filename"].(string) + result = append(result, messageMediaArg{ + Path: path, + Type: strings.TrimSpace(typ), + Filename: strings.TrimSpace(filename), + }) + } + return result, nil +} + +func (t *MessageTool) buildMediaParts( + channel, chatID, content string, + mediaArgs []messageMediaArg, +) ([]bus.MediaPart, error) { + if len(mediaArgs) == 0 { + return nil, nil + } + if !t.localMediaEnabled { + return nil, fmt.Errorf("message media attachments are disabled") + } + if t.mediaStore == nil { + return nil, fmt.Errorf("media store not configured") + } + if strings.TrimSpace(t.workspace) == "" { + return nil, fmt.Errorf("message media delivery is not configured") + } + + scope := fmt.Sprintf("tool:message:%s:%s", channel, chatID) + parts := make([]bus.MediaPart, 0, len(mediaArgs)) + for i, item := range mediaArgs { + resolved, err := fstools.ValidatePathWithAllowPaths(item.Path, t.workspace, t.restrict, t.allowPaths) + if err != nil { + return nil, fmt.Errorf("invalid media[%d].path: %w", i, err) + } + info, err := os.Stat(resolved) + if err != nil { + return nil, fmt.Errorf("media[%d] file not found: %w", i, err) + } + if info.IsDir() { + return nil, fmt.Errorf("media[%d] path is a directory, expected a file", i) + } + if t.maxFileSize > 0 && info.Size() > int64(t.maxFileSize) { + return nil, fmt.Errorf("media[%d] file too large: %d bytes (max %d bytes)", i, info.Size(), t.maxFileSize) + } + + filename := item.Filename + if filename == "" { + filename = filepath.Base(resolved) + } + contentType := detectMessageMediaType(resolved) + partType := normalizeMessageMediaType(item.Type, filename, contentType) + ref, err := t.mediaStore.Store(resolved, media.MediaMeta{ + Filename: filename, + ContentType: contentType, + Source: "tool:message", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, scope) + if err != nil { + return nil, fmt.Errorf("failed to register media[%d]: %w", i, err) + } + + part := bus.MediaPart{ + Type: partType, + Ref: ref, + Filename: filename, + ContentType: contentType, + } + if i == 0 && content != "" { + part.Caption = content + } + parts = append(parts, part) + } + return parts, nil +} + +func detectMessageMediaType(path string) string { + kind, err := filetype.MatchFile(path) + if err == nil && kind != filetype.Unknown { + return kind.MIME.Value + } + if ext := filepath.Ext(path); ext != "" { + if t := mime.TypeByExtension(ext); t != "" { + return t + } + } + return "application/octet-stream" +} + +func normalizeMessageMediaType(typeHint, filename, contentType string) string { + switch strings.ToLower(strings.TrimSpace(typeHint)) { + case "image", "audio", "video", "file": + return strings.ToLower(strings.TrimSpace(typeHint)) + } + + ct := strings.ToLower(strings.TrimSpace(contentType)) + switch { + case strings.HasPrefix(ct, "image/"): + return "image" + case strings.HasPrefix(ct, "audio/"): + return "audio" + case strings.HasPrefix(ct, "video/"): + return "video" + } + + switch strings.ToLower(filepath.Ext(filename)) { + case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp": + return "image" + case ".mp3", ".wav", ".ogg", ".oga", ".m4a", ".flac": + return "audio" + case ".mp4", ".mov", ".mkv", ".webm", ".avi": + return "video" + default: + return "file" + } +} diff --git a/pkg/tools/integration/message_test.go b/pkg/tools/integration/message_test.go index c7b7d2b6e..eea345c1c 100644 --- a/pkg/tools/integration/message_test.go +++ b/pkg/tools/integration/message_test.go @@ -3,8 +3,13 @@ package integrationtools import ( "context" "errors" + "os" + "path/filepath" + "regexp" "testing" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/session" ) @@ -12,10 +17,17 @@ func TestMessageTool_Execute_Success(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID, sentContent string - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { sentChannel = channel sentChatID = chatID sentContent = content + if len(mediaParts) != 0 { + t.Fatalf("expected no media parts, got %d", len(mediaParts)) + } if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil { t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v", ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx)) @@ -67,7 +79,11 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID string - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { sentChannel = channel sentChatID = chatID return nil @@ -102,7 +118,11 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() sendErr := errors.New("network error") - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { return sendErr }) @@ -142,12 +162,12 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) { result := tool.Execute(ctx, args) - // Verify error result for missing content + // Verify error result for missing content/media if !result.IsError { - t.Error("Expected IsError=true for missing content") + t.Error("Expected IsError=true for missing content/media") } - if result.ForLLM != "content is required" { - t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM) + if result.ForLLM != "content or media is required" { + t.Errorf("Expected ForLLM 'content or media is required', got '%s'", result.ForLLM) } } @@ -155,7 +175,11 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() // No WithToolContext — channel/chatID are empty - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { return nil }) @@ -228,7 +252,7 @@ func TestMessageTool_Parameters(t *testing.T) { // Check required properties required, ok := params["required"].([]string) if !ok || len(required) != 1 || required[0] != "content" { - t.Error("Expected 'content' to be required") + t.Fatal("Expected content-only required schema when local media is disabled") } // Check content property @@ -240,6 +264,10 @@ func TestMessageTool_Parameters(t *testing.T) { t.Error("Expected content type to be 'string'") } + if _, hasMedia := props["media"]; hasMedia { + t.Fatal("did not expect 'media' property when local media is disabled") + } + // Check channel property (optional) channelProp, ok := props["channel"].(map[string]any) if !ok { @@ -268,11 +296,65 @@ func TestMessageTool_Parameters(t *testing.T) { } } +func TestMessageTool_Parameters_WithLocalMediaEnabled(t *testing.T) { + tool := NewMessageTool() + tool.ConfigureLocalMedia(t.TempDir(), true, 1024*1024, nil) + params := tool.Parameters() + + props, ok := params["properties"].(map[string]any) + if !ok { + t.Fatal("Expected properties to be a map") + } + mediaProp, ok := props["media"].(map[string]any) + if !ok { + t.Fatal("Expected 'media' property") + } + if mediaProp["type"] != "array" { + t.Error("Expected media type to be 'array'") + } + anyOf, ok := params["anyOf"].([]map[string]any) + if !ok || len(anyOf) != 2 { + t.Fatal("Expected anyOf content/media requirement") + } + if _, ok := params["required"]; ok { + t.Fatal("did not expect top-level required content when media is enabled") + } +} + +func TestMessageTool_Execute_WithMediaDisabled(t *testing.T) { + tool := NewMessageTool() + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { + t.Fatal("send callback should not run when message media is disabled") + return nil + }) + + ctx := WithToolContext(context.Background(), "telegram", "-1001") + result := tool.Execute(ctx, map[string]any{ + "media": []any{ + map[string]any{"path": "photo.jpg"}, + }, + }) + if !result.IsError { + t.Fatal("expected error when message media is disabled") + } + if result.ForLLM != "message media attachments are disabled; enable tools.message.media_enabled to send local media through message" { + t.Fatalf("unexpected error: %q", result.ForLLM) + } +} + func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) { tool := NewMessageTool() var sentReplyTo string - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { sentReplyTo = replyToMessageID return nil }) @@ -297,7 +379,11 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) { var gotAgentID, gotSessionKey string var gotScope *session.SessionScope - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { gotAgentID = ToolAgentID(ctx) gotSessionKey = ToolSessionKey(ctx) gotScope = ToolSessionScope(ctx) @@ -329,3 +415,55 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) { t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope) } } + +func TestMessageTool_Execute_WithMedia(t *testing.T) { + tool := NewMessageTool() + store := media.NewFileMediaStore() + dir := t.TempDir() + imgPath := filepath.Join(dir, "photo.jpg") + if err := os.WriteFile(imgPath, []byte("fake image bytes"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + tool.ConfigureLocalMedia(dir, true, 1024*1024, []*regexp.Regexp{}) + tool.SetMediaStore(store) + + var gotContent string + var gotParts []bus.MediaPart + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { + gotContent = content + gotParts = append([]bus.MediaPart(nil), mediaParts...) + return nil + }) + + ctx := WithToolContext(context.Background(), "telegram", "-1001") + result := tool.Execute(ctx, map[string]any{ + "content": "Caption text", + "media": []any{ + map[string]any{ + "path": imgPath, + }, + }, + }) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + if gotContent != "Caption text" { + t.Fatalf("content = %q, want Caption text", gotContent) + } + if len(gotParts) != 1 { + t.Fatalf("expected 1 media part, got %d", len(gotParts)) + } + if gotParts[0].Caption != "Caption text" { + t.Fatalf("first part caption = %q, want Caption text", gotParts[0].Caption) + } + if gotParts[0].Ref == "" { + t.Fatal("expected media ref to be populated") + } + if gotParts[0].Type == "" { + t.Fatal("expected media type to be inferred") + } +}