diff --git a/pkg/agent/agent_init.go b/pkg/agent/agent_init.go index 50f0227a1..14b3f8bfe 100644 --- a/pkg/agent/agent_init.go +++ b/pkg/agent/agent_init.go @@ -161,9 +161,16 @@ func registerSharedTools( // Message tool if cfg.Tools.IsToolEnabled("message") { messageTool := tools.NewMessageTool() + messageTool.ConfigureLocalMedia( + agent.Workspace, + cfg.Agents.Defaults.RestrictToWorkspace, + cfg.Agents.Defaults.GetMaxMediaSize(), + allowReadPaths, + ) messageTool.SetSendCallback(func( ctx context.Context, channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, ) error { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() @@ -173,6 +180,15 @@ func registerSharedTools( tools.ToolSessionKey(ctx), tools.ToolSessionScope(ctx), ) + if len(mediaParts) > 0 { + return msgBus.PublishOutboundMedia(pubCtx, bus.OutboundMediaMessage{ + Context: outboundCtx, + AgentID: outboundAgentID, + SessionKey: outboundSessionKey, + Scope: outboundScope, + Parts: mediaParts, + }) + } return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Context: outboundCtx, AgentID: outboundAgentID, diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 29c2b32ea..6196b3f0b 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -377,7 +377,11 @@ func TestPublishResponseIfNeeded_DismissesToolFeedbackWhenMessageToolAlreadySent t.Fatal("expected default agent") } mt := tools.NewMessageTool() - mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + mt.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { return nil }) defaultAgent.Tools.Register(mt) diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index d09c021c7..5fd28806c 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -497,10 +497,18 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed) } + caption := firstMediaCaption(msg.Parts) + sentAny := false for _, part := range msg.Parts { if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil { return nil, err } + sentAny = true + } + if sentAny && caption != "" { + if _, err := c.sendText(ctx, msg.ChatID, caption); err != nil { + return nil, err + } } if hasTrackedMsg { @@ -557,6 +565,15 @@ func (c *FeishuChannel) sendMediaPart( return nil } +func firstMediaCaption(parts []bus.MediaPart) string { + for _, part := range parts { + if caption := strings.TrimSpace(part.Caption); caption != "" { + return caption + } + } + return "" +} + // --- Inbound message handling --- func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index fa62a4605..566422fdb 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -171,6 +171,8 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed) } + caption := slackFirstMediaCaption(msg.Parts) + sentAny := false for _, part := range msg.Parts { localPath, err := store.Resolve(part.Ref) if err != nil { @@ -205,6 +207,17 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa }) return nil, fmt.Errorf("slack send media: %w", channels.ErrTemporary) } + sentAny = true + } + + if sentAny && caption != "" { + opts := []slack.MsgOption{slack.MsgOptionText(caption, false)} + if threadTS != "" { + opts = append(opts, slack.MsgOptionTS(threadTS)) + } + if _, _, err := c.api.PostMessageContext(ctx, channelID, opts...); err != nil { + return nil, fmt.Errorf("slack send media caption fallback: %w", channels.ErrTemporary) + } } // UploadFile does not expose the posted message timestamp in its @@ -212,6 +225,15 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa return nil, nil } +func slackFirstMediaCaption(parts []bus.MediaPart) string { + for _, part := range parts { + if caption := strings.TrimSpace(part.Caption); caption != "" { + return caption + } + } + return "" +} + // ReactToMessage implements channels.ReactionCapable. // It adds an "eyes" (👀) reaction to the inbound message and returns an undo function // that removes the reaction. diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 45672e5ee..1a57fc6ed 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -45,6 +45,7 @@ var ( ) const defaultMediaGroupDelay = 500 * time.Millisecond +const telegramCaptionLimit = 1024 type TelegramChannel struct { *channels.BaseChannel @@ -639,6 +640,34 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe } var messageIDs []string + leadingCaption := telegramLeadingCaption(msg.Parts) + if len([]rune(leadingCaption)) > telegramCaptionLimit { + leadingIDs, leadingErr := c.sendCaptionText(ctx, chatID, threadID, leadingCaption) + if leadingErr != nil { + return nil, leadingErr + } + messageIDs = append(messageIDs, leadingIDs...) + msg = telegramClearMediaCaptions(msg) + } + + if len(msg.Parts) > 1 && telegramCanSendMediaGroup(msg.Parts) { + groupIDs, err := c.sendImageMediaGroups(ctx, chatID, threadID, store, msg.Parts) + if err != nil { + logger.ErrorCF("telegram", "Failed to send media group", map[string]any{ + "count": len(msg.Parts), + "error": err.Error(), + }) + return nil, fmt.Errorf("telegram send media group: %w", channels.ErrTemporary) + } + if len(groupIDs) > 0 { + messageIDs = append(messageIDs, groupIDs...) + if hasTrackedMsg { + c.dismissTrackedToolFeedbackMessage(ctx, trackedChatID, trackedMsgID) + } + return messageIDs, nil + } + } + for _, part := range msg.Parts { localPath, err := store.Resolve(part.Ref) if err != nil { @@ -742,6 +771,154 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe return messageIDs, nil } +func telegramCanSendMediaGroup(parts []bus.MediaPart) bool { + if len(parts) < 2 { + return false + } + for _, part := range parts { + if part.Type != "image" { + return false + } + } + return true +} + +func (c *TelegramChannel) sendImageMediaGroups( + ctx context.Context, + chatID int64, + threadID int, + store media.MediaStore, + parts []bus.MediaPart, +) ([]string, error) { + const maxGroupSize = 10 + + messageIDs := make([]string, 0, len(parts)) + for start := 0; start < len(parts); start += maxGroupSize { + end := start + maxGroupSize + if end > len(parts) { + end = len(parts) + } + groupIDs, err := c.sendSingleImageMediaGroup(ctx, chatID, threadID, store, parts[start:end]) + if err != nil { + return nil, err + } + messageIDs = append(messageIDs, groupIDs...) + } + return messageIDs, nil +} + +func (c *TelegramChannel) sendSingleImageMediaGroup( + ctx context.Context, + chatID int64, + threadID int, + store media.MediaStore, + parts []bus.MediaPart, +) ([]string, error) { + opened := make([]*os.File, 0, len(parts)) + defer func() { + for _, file := range opened { + file.Close() + } + }() + + inputMedia := make([]telego.InputMedia, 0, len(parts)) + for i, part := range parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("telegram", "Failed to resolve media ref for media group", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + return nil, err + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("telegram", "Failed to open media file for media group", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + return nil, err + } + opened = append(opened, file) + + mediaItem := &telego.InputMediaPhoto{ + Type: telego.MediaTypePhoto, + Media: telego.InputFile{File: file}, + } + if i == 0 { + mediaItem.Caption = part.Caption + } + inputMedia = append(inputMedia, mediaItem) + } + + results, err := c.bot.SendMediaGroup(ctx, &telego.SendMediaGroupParams{ + ChatID: tu.ID(chatID), + MessageThreadID: threadID, + Media: inputMedia, + }) + if err != nil { + return nil, err + } + + messageIDs := make([]string, 0, len(results)) + for _, result := range results { + messageIDs = append(messageIDs, strconv.Itoa(result.MessageID)) + } + return messageIDs, nil +} + +func (c *TelegramChannel) sendCaptionText( + ctx context.Context, + chatID int64, + threadID int, + text string, +) ([]string, error) { + text = strings.TrimSpace(text) + if text == "" { + return nil, nil + } + chunks := channels.SplitMessage(text, c.MaxMessageLength()) + messageIDs := make([]string, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + msgID, err := c.sendChunk(ctx, sendChunkParams{ + chatID: chatID, + threadID: threadID, + content: chunk, + mdFallback: chunk, + useMarkdownV2: false, + }) + if err != nil { + return nil, err + } + messageIDs = append(messageIDs, msgID) + } + return messageIDs, nil +} + +func telegramLeadingCaption(parts []bus.MediaPart) string { + if len(parts) == 0 { + return "" + } + return strings.TrimSpace(parts[0].Caption) +} + +func telegramClearMediaCaptions(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage { + if len(msg.Parts) == 0 { + return msg + } + cloned := msg + cloned.Parts = append([]bus.MediaPart(nil), msg.Parts...) + for i := range cloned.Parts { + cloned.Parts[i].Caption = "" + } + return cloned +} + func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { if message != nil && strings.TrimSpace(message.MediaGroupID) != "" { return c.bufferMediaGroupMessage(ctx, message) diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go index 0ebde1328..db5a8e784 100644 --- a/pkg/channels/telegram/telegram_test.go +++ b/pkg/channels/telegram/telegram_test.go @@ -110,6 +110,17 @@ func successResponseWithMessageID(t *testing.T, messageID int) *ta.Response { return &ta.Response{Ok: true, Result: b} } +func successMediaGroupResponse(t *testing.T, messageIDs ...int) *ta.Response { + t.Helper() + messages := make([]telego.Message, 0, len(messageIDs)) + for _, messageID := range messageIDs { + messages = append(messages, telego.Message{MessageID: messageID}) + } + b, err := json.Marshal(messages) + require.NoError(t, err) + return &ta.Response{Ok: true, Result: b} +} + func successUserResponse(t *testing.T, user *telego.User) *ta.Response { t.Helper() b, err := json.Marshal(user) @@ -237,6 +248,260 @@ func TestSendMedia_ImageNonDimensionErrorDoesNotFallback(t *testing.T) { assert.NotContains(t, caller.calls[0].URL, "sendDocument") } +func TestSendMedia_MultipleImagesUseMediaGroup(t *testing.T) { + constructor := &multipartRecordingConstructor{} + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + if strings.Contains(url, "sendMediaGroup") { + return successMediaGroupResponse(t, 101, 102), nil + } + t.Fatalf("unexpected API call: %s", url) + return nil, nil + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + firstPath := filepath.Join(tmpDir, "first.png") + secondPath := filepath.Join(tmpDir, "second.png") + require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644)) + require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644)) + + firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + secondRef, err := store.Store(secondPath, media.MediaMeta{Filename: "second.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: []bus.MediaPart{ + {Type: "image", Ref: firstRef, Caption: "album caption"}, + {Type: "image", Ref: secondRef}, + }, + }) + + require.NoError(t, err) + assert.Equal(t, []string{"101", "102"}, ids) + require.Len(t, caller.calls, 1) + assert.Contains(t, caller.calls[0].URL, "sendMediaGroup") + require.Len(t, constructor.calls, 1) + require.Len(t, constructor.calls[0].FileSizes, 2) + + var mediaPayload []map[string]any + require.NoError(t, json.Unmarshal([]byte(constructor.calls[0].Parameters["media"]), &mediaPayload)) + require.Len(t, mediaPayload, 2) + assert.Equal(t, "album caption", mediaPayload[0]["caption"]) + _, hasSecondCaption := mediaPayload[1]["caption"] + assert.False(t, hasSecondCaption) +} + +func TestSendMedia_MoreThanTenImagesSplitIntoMediaGroups(t *testing.T) { + constructor := &multipartRecordingConstructor{} + callIndex := 0 + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + if !strings.Contains(url, "sendMediaGroup") { + t.Fatalf("unexpected API call: %s", url) + } + callIndex++ + if callIndex == 1 { + return successMediaGroupResponse(t, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010), nil + } + if callIndex == 2 { + return successMediaGroupResponse(t, 1011, 1012, 1013, 1014, 1015), nil + } + t.Fatalf("unexpected sendMediaGroup call #%d", callIndex) + return nil, nil + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + parts := make([]bus.MediaPart, 0, 15) + for i := 0; i < 15; i++ { + path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png") + require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644)) + ref, err := store.Store(path, media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + part := bus.MediaPart{Type: "image", Ref: ref} + if i == 0 { + part.Caption = "long album caption" + } + parts = append(parts, part) + } + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: parts, + }) + + require.NoError(t, err) + assert.Equal(t, []string{ + "1001", "1002", "1003", "1004", "1005", + "1006", "1007", "1008", "1009", "1010", + "1011", "1012", "1013", "1014", "1015", + }, ids) + require.Len(t, caller.calls, 2) + require.Len(t, constructor.calls, 2) +} + +func TestSendMedia_SingleImageLongCaptionSendsTextFirst(t *testing.T) { + constructor := &multipartRecordingConstructor{} + longCaption := strings.Repeat("a", telegramCaptionLimit) + " tail overflow" + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + switch { + case strings.Contains(url, "sendMessage"): + return successResponseWithMessageID(t, 201), nil + case strings.Contains(url, "sendPhoto"): + return successResponseWithMessageID(t, 202), nil + default: + t.Fatalf("unexpected API call: %s", url) + return nil, nil + } + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "image.png") + require.NoError(t, os.WriteFile(path, []byte("img"), 0o644)) + ref, err := store.Store(path, media.MediaMeta{Filename: "image.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: []bus.MediaPart{{ + Type: "image", + Ref: ref, + Caption: longCaption, + }}, + }) + + require.NoError(t, err) + assert.Equal(t, []string{"201", "202"}, ids) + require.Len(t, caller.calls, 2) + assert.Contains(t, caller.calls[0].URL, "sendMessage") + assert.Contains(t, caller.calls[1].URL, "sendPhoto") + assert.Equal(t, "", constructor.calls[0].Parameters["caption"]) +} + +func TestSendMedia_MediaGroupLongCaptionSendsTextFirst(t *testing.T) { + constructor := &multipartRecordingConstructor{} + longCaption := strings.Repeat("b", telegramCaptionLimit) + " trailing explanation" + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + switch { + case strings.Contains(url, "sendMessage"): + return successResponseWithMessageID(t, 301), nil + case strings.Contains(url, "sendMediaGroup"): + return successMediaGroupResponse(t, 302, 303), nil + default: + t.Fatalf("unexpected API call: %s", url) + return nil, nil + } + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + firstPath := filepath.Join(tmpDir, "first.png") + secondPath := filepath.Join(tmpDir, "second.png") + require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644)) + require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644)) + + firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + secondRef, err := store.Store(secondPath, media.MediaMeta{Filename: "second.png", ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: []bus.MediaPart{ + {Type: "image", Ref: firstRef, Caption: longCaption}, + {Type: "image", Ref: secondRef}, + }, + }) + + require.NoError(t, err) + assert.Equal(t, []string{"301", "302", "303"}, ids) + require.Len(t, caller.calls, 2) + assert.Contains(t, caller.calls[0].URL, "sendMessage") + assert.Contains(t, caller.calls[1].URL, "sendMediaGroup") +} + +func TestSendMedia_MultiGroupLongCaptionSendsTextBeforeGroups(t *testing.T) { + constructor := &multipartRecordingConstructor{} + longCaption := strings.Repeat("c", telegramCaptionLimit) + " overflow before second album" + callOrder := make([]string, 0, 3) + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + switch { + case strings.Contains(url, "sendMessage"): + callOrder = append(callOrder, "text") + return successResponseWithMessageID(t, 499), nil + case strings.Contains(url, "sendMediaGroup"): + callOrder = append(callOrder, "group") + if len(callOrder) == 2 { + return successMediaGroupResponse(t, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410), nil + } + if len(callOrder) == 3 { + return successMediaGroupResponse(t, 411, 412, 413, 414, 415), nil + } + t.Fatalf("unexpected sendMediaGroup order: %v", callOrder) + return nil, nil + default: + t.Fatalf("unexpected API call: %s", url) + return nil, nil + } + }, + } + ch := newTestChannelWithConstructor(t, caller, constructor) + + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + tmpDir := t.TempDir() + parts := make([]bus.MediaPart, 0, 15) + for i := 0; i < 15; i++ { + path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png") + require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644)) + ref, err := store.Store(path, media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"}, "scope-1") + require.NoError(t, err) + part := bus.MediaPart{Type: "image", Ref: ref} + if i == 0 { + part.Caption = longCaption + } + parts = append(parts, part) + } + + ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "12345", + Parts: parts, + }) + + require.NoError(t, err) + assert.Equal(t, []string{ + "499", + "401", "402", "403", "404", "405", + "406", "407", "408", "409", "410", + "411", "412", "413", "414", "415", + }, ids) + assert.Equal(t, []string{"text", "group", "group"}, callOrder) +} + func TestSend_EmptyContent(t *testing.T) { caller := &stubCaller{ callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { diff --git a/pkg/tools/integration/message.go b/pkg/tools/integration/message.go index 98d87bcb3..f7b7b7fdc 100644 --- a/pkg/tools/integration/message.go +++ b/pkg/tools/integration/message.go @@ -3,10 +3,32 @@ package integrationtools import ( "context" "fmt" + "mime" + "os" + "path/filepath" + "regexp" + "strings" "sync" + + "github.com/h2non/filetype" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" + fstools "github.com/sipeed/picoclaw/pkg/tools/fs" ) -type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error +type SendCallbackWithContext func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, +) error + +type messageMediaArg struct { + Path string + Type string + Filename string +} // sentTarget records the channel+chatID that the message tool sent to. type sentTarget struct { @@ -16,10 +38,13 @@ type sentTarget struct { type MessageTool struct { sendCallback SendCallbackWithContext + workspace string + restrict bool + maxFileSize int + mediaStore media.MediaStore + allowPaths []*regexp.Regexp mu sync.Mutex - // sentTargets tracks targets sent to in the current round, keyed by session key - // to support parallel turns for different sessions. - sentTargets map[string][]sentTarget + sentTargets map[string][]sentTarget } func NewMessageTool() *MessageTool { @@ -33,7 +58,7 @@ func (t *MessageTool) Name() string { } func (t *MessageTool) Description() string { - return "Send a message to user on a chat channel. Use this when you want to communicate something." + return "Send a message to the user on a chat channel. Supports text-only, media-only, or text with media attachments." } func (t *MessageTool) Parameters() map[string]any { @@ -42,7 +67,29 @@ func (t *MessageTool) Parameters() map[string]any { "properties": map[string]any{ "content": map[string]any{ "type": "string", - "description": "The message content to send", + "description": "Optional message text. When media is present, this text is used as the caption/body for the media message.", + }, + "media": map[string]any{ + "type": "array", + "description": "Optional local media attachments to send with the message.", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the local file. Relative paths are resolved from workspace.", + }, + "type": map[string]any{ + "type": "string", + "description": "Optional media type hint: image, audio, video, or file.", + }, + "filename": map[string]any{ + "type": "string", + "description": "Optional display filename. Defaults to the basename of path.", + }, + }, + "required": []string{"path"}, + }, }, "channel": map[string]any{ "type": "string", @@ -57,10 +104,32 @@ func (t *MessageTool) Parameters() map[string]any { "description": "Optional: reply target message ID for channels that support threaded replies", }, }, - "required": []string{"content"}, + "anyOf": []map[string]any{ + {"required": []string{"content"}}, + {"required": []string{"media"}}, + }, } } +func (t *MessageTool) ConfigureLocalMedia( + workspace string, + restrict bool, + maxFileSize int, + allowPaths []*regexp.Regexp, +) { + t.workspace = workspace + t.restrict = restrict + if maxFileSize <= 0 { + maxFileSize = config.DefaultMaxMediaSize + } + t.maxFileSize = maxFileSize + t.allowPaths = allowPaths +} + +func (t *MessageTool) SetMediaStore(store media.MediaStore) { + t.mediaStore = store +} + // ResetSentInRound resets the per-round send tracker for the given session key. // Called by the agent loop at the start of each inbound message processing round. func (t *MessageTool) ResetSentInRound(sessionKey string) { @@ -98,9 +167,14 @@ func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) { } func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult { - content, ok := args["content"].(string) - if !ok { - return &ToolResult{ForLLM: "content is required", IsError: true} + content, _ := args["content"].(string) + content = strings.TrimSpace(content) + mediaArgs, err := parseMessageMediaArgs(args["media"]) + if err != nil { + return &ToolResult{ForLLM: err.Error(), IsError: true} + } + if content == "" && len(mediaArgs) == 0 { + return &ToolResult{ForLLM: "content or media is required", IsError: true} } channel, _ := args["channel"].(string) @@ -122,7 +196,12 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } - if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil { + parts, err := t.buildMediaParts(channel, chatID, content, mediaArgs) + if err != nil { + return &ToolResult{ForLLM: err.Error(), IsError: true, Err: err} + } + + if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID, parts); err != nil { return &ToolResult{ ForLLM: fmt.Sprintf("sending message: %v", err), IsError: true, @@ -135,9 +214,146 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID}) t.mu.Unlock() - // Silent: user already received the message directly + status := fmt.Sprintf("Message sent to %s:%s", channel, chatID) + if len(parts) > 0 { + status = fmt.Sprintf("Message with %d media attachment(s) sent to %s:%s", len(parts), channel, chatID) + } + return &ToolResult{ - ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), + ForLLM: status, Silent: true, } } + +func parseMessageMediaArgs(raw any) ([]messageMediaArg, error) { + if raw == nil { + return nil, nil + } + items, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("media must be an array") + } + result := make([]messageMediaArg, 0, len(items)) + for i, item := range items { + obj, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("media[%d] must be an object", i) + } + path, _ := obj["path"].(string) + path = strings.TrimSpace(path) + if path == "" { + return nil, fmt.Errorf("media[%d].path is required", i) + } + typ, _ := obj["type"].(string) + filename, _ := obj["filename"].(string) + result = append(result, messageMediaArg{ + Path: path, + Type: strings.TrimSpace(typ), + Filename: strings.TrimSpace(filename), + }) + } + return result, nil +} + +func (t *MessageTool) buildMediaParts( + channel, chatID, content string, + mediaArgs []messageMediaArg, +) ([]bus.MediaPart, error) { + if len(mediaArgs) == 0 { + return nil, nil + } + if t.mediaStore == nil { + return nil, fmt.Errorf("media store not configured") + } + if strings.TrimSpace(t.workspace) == "" { + return nil, fmt.Errorf("message media delivery is not configured") + } + + scope := fmt.Sprintf("tool:message:%s:%s", channel, chatID) + parts := make([]bus.MediaPart, 0, len(mediaArgs)) + for i, item := range mediaArgs { + resolved, err := fstools.ValidatePathWithAllowPaths(item.Path, t.workspace, t.restrict, t.allowPaths) + if err != nil { + return nil, fmt.Errorf("invalid media[%d].path: %w", i, err) + } + info, err := os.Stat(resolved) + if err != nil { + return nil, fmt.Errorf("media[%d] file not found: %w", i, err) + } + if info.IsDir() { + return nil, fmt.Errorf("media[%d] path is a directory, expected a file", i) + } + if t.maxFileSize > 0 && info.Size() > int64(t.maxFileSize) { + return nil, fmt.Errorf("media[%d] file too large: %d bytes (max %d bytes)", i, info.Size(), t.maxFileSize) + } + + filename := item.Filename + if filename == "" { + filename = filepath.Base(resolved) + } + contentType := detectMessageMediaType(resolved) + partType := normalizeMessageMediaType(item.Type, filename, contentType) + ref, err := t.mediaStore.Store(resolved, media.MediaMeta{ + Filename: filename, + ContentType: contentType, + Source: "tool:message", + CleanupPolicy: media.CleanupPolicyForgetOnly, + }, scope) + if err != nil { + return nil, fmt.Errorf("failed to register media[%d]: %w", i, err) + } + + part := bus.MediaPart{ + Type: partType, + Ref: ref, + Filename: filename, + ContentType: contentType, + } + if i == 0 && content != "" { + part.Caption = content + } + parts = append(parts, part) + } + return parts, nil +} + +func detectMessageMediaType(path string) string { + kind, err := filetype.MatchFile(path) + if err == nil && kind != filetype.Unknown { + return kind.MIME.Value + } + if ext := filepath.Ext(path); ext != "" { + if t := mime.TypeByExtension(ext); t != "" { + return t + } + } + return "application/octet-stream" +} + +func normalizeMessageMediaType(typeHint, filename, contentType string) string { + switch strings.ToLower(strings.TrimSpace(typeHint)) { + case "image", "audio", "video", "file": + return strings.ToLower(strings.TrimSpace(typeHint)) + } + + ct := strings.ToLower(strings.TrimSpace(contentType)) + switch { + case strings.HasPrefix(ct, "image/"): + return "image" + case strings.HasPrefix(ct, "audio/"): + return "audio" + case strings.HasPrefix(ct, "video/"): + return "video" + } + + switch strings.ToLower(filepath.Ext(filename)) { + case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp": + return "image" + case ".mp3", ".wav", ".ogg", ".oga", ".m4a", ".flac": + return "audio" + case ".mp4", ".mov", ".mkv", ".webm", ".avi": + return "video" + default: + return "file" + } +} diff --git a/pkg/tools/integration/message_test.go b/pkg/tools/integration/message_test.go index c7b7d2b6e..2d3329d3d 100644 --- a/pkg/tools/integration/message_test.go +++ b/pkg/tools/integration/message_test.go @@ -3,8 +3,13 @@ package integrationtools import ( "context" "errors" + "os" + "path/filepath" + "regexp" "testing" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/session" ) @@ -12,10 +17,17 @@ func TestMessageTool_Execute_Success(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID, sentContent string - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { sentChannel = channel sentChatID = chatID sentContent = content + if len(mediaParts) != 0 { + t.Fatalf("expected no media parts, got %d", len(mediaParts)) + } if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil { t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v", ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx)) @@ -67,7 +79,11 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID string - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { sentChannel = channel sentChatID = chatID return nil @@ -102,7 +118,11 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() sendErr := errors.New("network error") - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { return sendErr }) @@ -142,12 +162,12 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) { result := tool.Execute(ctx, args) - // Verify error result for missing content + // Verify error result for missing content/media if !result.IsError { - t.Error("Expected IsError=true for missing content") + t.Error("Expected IsError=true for missing content/media") } - if result.ForLLM != "content is required" { - t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM) + if result.ForLLM != "content or media is required" { + t.Errorf("Expected ForLLM 'content or media is required', got '%s'", result.ForLLM) } } @@ -155,7 +175,11 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() // No WithToolContext — channel/chatID are empty - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { return nil }) @@ -226,9 +250,9 @@ func TestMessageTool_Parameters(t *testing.T) { } // Check required properties - required, ok := params["required"].([]string) - if !ok || len(required) != 1 || required[0] != "content" { - t.Error("Expected 'content' to be required") + anyOf, ok := params["anyOf"].([]map[string]any) + if !ok || len(anyOf) != 2 { + t.Fatal("Expected anyOf content/media requirement") } // Check content property @@ -240,6 +264,14 @@ func TestMessageTool_Parameters(t *testing.T) { t.Error("Expected content type to be 'string'") } + mediaProp, ok := props["media"].(map[string]any) + if !ok { + t.Fatal("Expected 'media' property") + } + if mediaProp["type"] != "array" { + t.Error("Expected media type to be 'array'") + } + // Check channel property (optional) channelProp, ok := props["channel"].(map[string]any) if !ok { @@ -272,7 +304,11 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) { tool := NewMessageTool() var sentReplyTo string - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { sentReplyTo = replyToMessageID return nil }) @@ -297,7 +333,11 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) { var gotAgentID, gotSessionKey string var gotScope *session.SessionScope - tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { gotAgentID = ToolAgentID(ctx) gotSessionKey = ToolSessionKey(ctx) gotScope = ToolSessionScope(ctx) @@ -329,3 +369,55 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) { t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope) } } + +func TestMessageTool_Execute_WithMedia(t *testing.T) { + tool := NewMessageTool() + store := media.NewFileMediaStore() + dir := t.TempDir() + imgPath := filepath.Join(dir, "photo.jpg") + if err := os.WriteFile(imgPath, []byte("fake image bytes"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + tool.ConfigureLocalMedia(dir, true, 1024*1024, []*regexp.Regexp{}) + tool.SetMediaStore(store) + + var gotContent string + var gotParts []bus.MediaPart + tool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + mediaParts []bus.MediaPart, + ) error { + gotContent = content + gotParts = append([]bus.MediaPart(nil), mediaParts...) + return nil + }) + + ctx := WithToolContext(context.Background(), "telegram", "-1001") + result := tool.Execute(ctx, map[string]any{ + "content": "Caption text", + "media": []any{ + map[string]any{ + "path": imgPath, + }, + }, + }) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + if gotContent != "Caption text" { + t.Fatalf("content = %q, want Caption text", gotContent) + } + if len(gotParts) != 1 { + t.Fatalf("expected 1 media part, got %d", len(gotParts)) + } + if gotParts[0].Caption != "Caption text" { + t.Fatalf("first part caption = %q, want Caption text", gotParts[0].Caption) + } + if gotParts[0].Ref == "" { + t.Fatal("expected media ref to be populated") + } + if gotParts[0].Type == "" { + t.Fatal("expected media type to be inferred") + } +}