feat(message): support media attachments in outbound tool

This commit is contained in:
Anton Bogdanovich
2026-05-11 16:04:26 -07:00
parent f09a7d67f7
commit 5a4e42d1b6
8 changed files with 836 additions and 27 deletions
+16
View File
@@ -161,9 +161,16 @@ func registerSharedTools(
// Message tool
if cfg.Tools.IsToolEnabled("message") {
messageTool := tools.NewMessageTool()
messageTool.ConfigureLocalMedia(
agent.Workspace,
cfg.Agents.Defaults.RestrictToWorkspace,
cfg.Agents.Defaults.GetMaxMediaSize(),
allowReadPaths,
)
messageTool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
@@ -173,6 +180,15 @@ func registerSharedTools(
tools.ToolSessionKey(ctx),
tools.ToolSessionScope(ctx),
)
if len(mediaParts) > 0 {
return msgBus.PublishOutboundMedia(pubCtx, bus.OutboundMediaMessage{
Context: outboundCtx,
AgentID: outboundAgentID,
SessionKey: outboundSessionKey,
Scope: outboundScope,
Parts: mediaParts,
})
}
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Context: outboundCtx,
AgentID: outboundAgentID,
+5 -1
View File
@@ -377,7 +377,11 @@ func TestPublishResponseIfNeeded_DismissesToolFeedbackWhenMessageToolAlreadySent
t.Fatal("expected default agent")
}
mt := tools.NewMessageTool()
mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
mt.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
return nil
})
defaultAgent.Tools.Register(mt)
+17
View File
@@ -497,10 +497,18 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess
return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
}
caption := firstMediaCaption(msg.Parts)
sentAny := false
for _, part := range msg.Parts {
if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil {
return nil, err
}
sentAny = true
}
if sentAny && caption != "" {
if _, err := c.sendText(ctx, msg.ChatID, caption); err != nil {
return nil, err
}
}
if hasTrackedMsg {
@@ -557,6 +565,15 @@ func (c *FeishuChannel) sendMediaPart(
return nil
}
func firstMediaCaption(parts []bus.MediaPart) string {
for _, part := range parts {
if caption := strings.TrimSpace(part.Caption); caption != "" {
return caption
}
}
return ""
}
// --- Inbound message handling ---
func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
+22
View File
@@ -171,6 +171,8 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
}
caption := slackFirstMediaCaption(msg.Parts)
sentAny := false
for _, part := range msg.Parts {
localPath, err := store.Resolve(part.Ref)
if err != nil {
@@ -205,6 +207,17 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
})
return nil, fmt.Errorf("slack send media: %w", channels.ErrTemporary)
}
sentAny = true
}
if sentAny && caption != "" {
opts := []slack.MsgOption{slack.MsgOptionText(caption, false)}
if threadTS != "" {
opts = append(opts, slack.MsgOptionTS(threadTS))
}
if _, _, err := c.api.PostMessageContext(ctx, channelID, opts...); err != nil {
return nil, fmt.Errorf("slack send media caption fallback: %w", channels.ErrTemporary)
}
}
// UploadFile does not expose the posted message timestamp in its
@@ -212,6 +225,15 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
return nil, nil
}
func slackFirstMediaCaption(parts []bus.MediaPart) string {
for _, part := range parts {
if caption := strings.TrimSpace(part.Caption); caption != "" {
return caption
}
}
return ""
}
// ReactToMessage implements channels.ReactionCapable.
// It adds an "eyes" (👀) reaction to the inbound message and returns an undo function
// that removes the reaction.
+177
View File
@@ -45,6 +45,7 @@ var (
)
const defaultMediaGroupDelay = 500 * time.Millisecond
const telegramCaptionLimit = 1024
type TelegramChannel struct {
*channels.BaseChannel
@@ -639,6 +640,34 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
}
var messageIDs []string
leadingCaption := telegramLeadingCaption(msg.Parts)
if len([]rune(leadingCaption)) > telegramCaptionLimit {
leadingIDs, leadingErr := c.sendCaptionText(ctx, chatID, threadID, leadingCaption)
if leadingErr != nil {
return nil, leadingErr
}
messageIDs = append(messageIDs, leadingIDs...)
msg = telegramClearMediaCaptions(msg)
}
if len(msg.Parts) > 1 && telegramCanSendMediaGroup(msg.Parts) {
groupIDs, err := c.sendImageMediaGroups(ctx, chatID, threadID, store, msg.Parts)
if err != nil {
logger.ErrorCF("telegram", "Failed to send media group", map[string]any{
"count": len(msg.Parts),
"error": err.Error(),
})
return nil, fmt.Errorf("telegram send media group: %w", channels.ErrTemporary)
}
if len(groupIDs) > 0 {
messageIDs = append(messageIDs, groupIDs...)
if hasTrackedMsg {
c.dismissTrackedToolFeedbackMessage(ctx, trackedChatID, trackedMsgID)
}
return messageIDs, nil
}
}
for _, part := range msg.Parts {
localPath, err := store.Resolve(part.Ref)
if err != nil {
@@ -742,6 +771,154 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
return messageIDs, nil
}
func telegramCanSendMediaGroup(parts []bus.MediaPart) bool {
if len(parts) < 2 {
return false
}
for _, part := range parts {
if part.Type != "image" {
return false
}
}
return true
}
func (c *TelegramChannel) sendImageMediaGroups(
ctx context.Context,
chatID int64,
threadID int,
store media.MediaStore,
parts []bus.MediaPart,
) ([]string, error) {
const maxGroupSize = 10
messageIDs := make([]string, 0, len(parts))
for start := 0; start < len(parts); start += maxGroupSize {
end := start + maxGroupSize
if end > len(parts) {
end = len(parts)
}
groupIDs, err := c.sendSingleImageMediaGroup(ctx, chatID, threadID, store, parts[start:end])
if err != nil {
return nil, err
}
messageIDs = append(messageIDs, groupIDs...)
}
return messageIDs, nil
}
func (c *TelegramChannel) sendSingleImageMediaGroup(
ctx context.Context,
chatID int64,
threadID int,
store media.MediaStore,
parts []bus.MediaPart,
) ([]string, error) {
opened := make([]*os.File, 0, len(parts))
defer func() {
for _, file := range opened {
file.Close()
}
}()
inputMedia := make([]telego.InputMedia, 0, len(parts))
for i, part := range parts {
localPath, err := store.Resolve(part.Ref)
if err != nil {
logger.ErrorCF("telegram", "Failed to resolve media ref for media group", map[string]any{
"ref": part.Ref,
"error": err.Error(),
})
return nil, err
}
file, err := os.Open(localPath)
if err != nil {
logger.ErrorCF("telegram", "Failed to open media file for media group", map[string]any{
"path": localPath,
"error": err.Error(),
})
return nil, err
}
opened = append(opened, file)
mediaItem := &telego.InputMediaPhoto{
Type: telego.MediaTypePhoto,
Media: telego.InputFile{File: file},
}
if i == 0 {
mediaItem.Caption = part.Caption
}
inputMedia = append(inputMedia, mediaItem)
}
results, err := c.bot.SendMediaGroup(ctx, &telego.SendMediaGroupParams{
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Media: inputMedia,
})
if err != nil {
return nil, err
}
messageIDs := make([]string, 0, len(results))
for _, result := range results {
messageIDs = append(messageIDs, strconv.Itoa(result.MessageID))
}
return messageIDs, nil
}
func (c *TelegramChannel) sendCaptionText(
ctx context.Context,
chatID int64,
threadID int,
text string,
) ([]string, error) {
text = strings.TrimSpace(text)
if text == "" {
return nil, nil
}
chunks := channels.SplitMessage(text, c.MaxMessageLength())
messageIDs := make([]string, 0, len(chunks))
for _, chunk := range chunks {
chunk = strings.TrimSpace(chunk)
if chunk == "" {
continue
}
msgID, err := c.sendChunk(ctx, sendChunkParams{
chatID: chatID,
threadID: threadID,
content: chunk,
mdFallback: chunk,
useMarkdownV2: false,
})
if err != nil {
return nil, err
}
messageIDs = append(messageIDs, msgID)
}
return messageIDs, nil
}
func telegramLeadingCaption(parts []bus.MediaPart) string {
if len(parts) == 0 {
return ""
}
return strings.TrimSpace(parts[0].Caption)
}
func telegramClearMediaCaptions(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage {
if len(msg.Parts) == 0 {
return msg
}
cloned := msg
cloned.Parts = append([]bus.MediaPart(nil), msg.Parts...)
for i := range cloned.Parts {
cloned.Parts[i].Caption = ""
}
return cloned
}
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
if message != nil && strings.TrimSpace(message.MediaGroupID) != "" {
return c.bufferMediaGroupMessage(ctx, message)
+265
View File
@@ -110,6 +110,17 @@ func successResponseWithMessageID(t *testing.T, messageID int) *ta.Response {
return &ta.Response{Ok: true, Result: b}
}
func successMediaGroupResponse(t *testing.T, messageIDs ...int) *ta.Response {
t.Helper()
messages := make([]telego.Message, 0, len(messageIDs))
for _, messageID := range messageIDs {
messages = append(messages, telego.Message{MessageID: messageID})
}
b, err := json.Marshal(messages)
require.NoError(t, err)
return &ta.Response{Ok: true, Result: b}
}
func successUserResponse(t *testing.T, user *telego.User) *ta.Response {
t.Helper()
b, err := json.Marshal(user)
@@ -237,6 +248,260 @@ func TestSendMedia_ImageNonDimensionErrorDoesNotFallback(t *testing.T) {
assert.NotContains(t, caller.calls[0].URL, "sendDocument")
}
func TestSendMedia_MultipleImagesUseMediaGroup(t *testing.T) {
constructor := &multipartRecordingConstructor{}
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
if strings.Contains(url, "sendMediaGroup") {
return successMediaGroupResponse(t, 101, 102), nil
}
t.Fatalf("unexpected API call: %s", url)
return nil, nil
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
firstPath := filepath.Join(tmpDir, "first.png")
secondPath := filepath.Join(tmpDir, "second.png")
require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644))
require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644))
firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
secondRef, err := store.Store(secondPath, media.MediaMeta{Filename: "second.png", ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: []bus.MediaPart{
{Type: "image", Ref: firstRef, Caption: "album caption"},
{Type: "image", Ref: secondRef},
},
})
require.NoError(t, err)
assert.Equal(t, []string{"101", "102"}, ids)
require.Len(t, caller.calls, 1)
assert.Contains(t, caller.calls[0].URL, "sendMediaGroup")
require.Len(t, constructor.calls, 1)
require.Len(t, constructor.calls[0].FileSizes, 2)
var mediaPayload []map[string]any
require.NoError(t, json.Unmarshal([]byte(constructor.calls[0].Parameters["media"]), &mediaPayload))
require.Len(t, mediaPayload, 2)
assert.Equal(t, "album caption", mediaPayload[0]["caption"])
_, hasSecondCaption := mediaPayload[1]["caption"]
assert.False(t, hasSecondCaption)
}
func TestSendMedia_MoreThanTenImagesSplitIntoMediaGroups(t *testing.T) {
constructor := &multipartRecordingConstructor{}
callIndex := 0
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
if !strings.Contains(url, "sendMediaGroup") {
t.Fatalf("unexpected API call: %s", url)
}
callIndex++
if callIndex == 1 {
return successMediaGroupResponse(t, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010), nil
}
if callIndex == 2 {
return successMediaGroupResponse(t, 1011, 1012, 1013, 1014, 1015), nil
}
t.Fatalf("unexpected sendMediaGroup call #%d", callIndex)
return nil, nil
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
parts := make([]bus.MediaPart, 0, 15)
for i := 0; i < 15; i++ {
path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png")
require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644))
ref, err := store.Store(path, media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
part := bus.MediaPart{Type: "image", Ref: ref}
if i == 0 {
part.Caption = "long album caption"
}
parts = append(parts, part)
}
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: parts,
})
require.NoError(t, err)
assert.Equal(t, []string{
"1001", "1002", "1003", "1004", "1005",
"1006", "1007", "1008", "1009", "1010",
"1011", "1012", "1013", "1014", "1015",
}, ids)
require.Len(t, caller.calls, 2)
require.Len(t, constructor.calls, 2)
}
func TestSendMedia_SingleImageLongCaptionSendsTextFirst(t *testing.T) {
constructor := &multipartRecordingConstructor{}
longCaption := strings.Repeat("a", telegramCaptionLimit) + " tail overflow"
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
switch {
case strings.Contains(url, "sendMessage"):
return successResponseWithMessageID(t, 201), nil
case strings.Contains(url, "sendPhoto"):
return successResponseWithMessageID(t, 202), nil
default:
t.Fatalf("unexpected API call: %s", url)
return nil, nil
}
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "image.png")
require.NoError(t, os.WriteFile(path, []byte("img"), 0o644))
ref, err := store.Store(path, media.MediaMeta{Filename: "image.png", ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: []bus.MediaPart{{
Type: "image",
Ref: ref,
Caption: longCaption,
}},
})
require.NoError(t, err)
assert.Equal(t, []string{"201", "202"}, ids)
require.Len(t, caller.calls, 2)
assert.Contains(t, caller.calls[0].URL, "sendMessage")
assert.Contains(t, caller.calls[1].URL, "sendPhoto")
assert.Equal(t, "", constructor.calls[0].Parameters["caption"])
}
func TestSendMedia_MediaGroupLongCaptionSendsTextFirst(t *testing.T) {
constructor := &multipartRecordingConstructor{}
longCaption := strings.Repeat("b", telegramCaptionLimit) + " trailing explanation"
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
switch {
case strings.Contains(url, "sendMessage"):
return successResponseWithMessageID(t, 301), nil
case strings.Contains(url, "sendMediaGroup"):
return successMediaGroupResponse(t, 302, 303), nil
default:
t.Fatalf("unexpected API call: %s", url)
return nil, nil
}
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
firstPath := filepath.Join(tmpDir, "first.png")
secondPath := filepath.Join(tmpDir, "second.png")
require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644))
require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644))
firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
secondRef, err := store.Store(secondPath, media.MediaMeta{Filename: "second.png", ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: []bus.MediaPart{
{Type: "image", Ref: firstRef, Caption: longCaption},
{Type: "image", Ref: secondRef},
},
})
require.NoError(t, err)
assert.Equal(t, []string{"301", "302", "303"}, ids)
require.Len(t, caller.calls, 2)
assert.Contains(t, caller.calls[0].URL, "sendMessage")
assert.Contains(t, caller.calls[1].URL, "sendMediaGroup")
}
func TestSendMedia_MultiGroupLongCaptionSendsTextBeforeGroups(t *testing.T) {
constructor := &multipartRecordingConstructor{}
longCaption := strings.Repeat("c", telegramCaptionLimit) + " overflow before second album"
callOrder := make([]string, 0, 3)
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
switch {
case strings.Contains(url, "sendMessage"):
callOrder = append(callOrder, "text")
return successResponseWithMessageID(t, 499), nil
case strings.Contains(url, "sendMediaGroup"):
callOrder = append(callOrder, "group")
if len(callOrder) == 2 {
return successMediaGroupResponse(t, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410), nil
}
if len(callOrder) == 3 {
return successMediaGroupResponse(t, 411, 412, 413, 414, 415), nil
}
t.Fatalf("unexpected sendMediaGroup order: %v", callOrder)
return nil, nil
default:
t.Fatalf("unexpected API call: %s", url)
return nil, nil
}
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
parts := make([]bus.MediaPart, 0, 15)
for i := 0; i < 15; i++ {
path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png")
require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644))
ref, err := store.Store(path, media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
part := bus.MediaPart{Type: "image", Ref: ref}
if i == 0 {
part.Caption = longCaption
}
parts = append(parts, part)
}
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: parts,
})
require.NoError(t, err)
assert.Equal(t, []string{
"499",
"401", "402", "403", "404", "405",
"406", "407", "408", "409", "410",
"411", "412", "413", "414", "415",
}, ids)
assert.Equal(t, []string{"text", "group", "group"}, callOrder)
}
func TestSend_EmptyContent(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
+229 -13
View File
@@ -3,10 +3,32 @@ package integrationtools
import (
"context"
"fmt"
"mime"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/h2non/filetype"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
fstools "github.com/sipeed/picoclaw/pkg/tools/fs"
)
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
type SendCallbackWithContext func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error
type messageMediaArg struct {
Path string
Type string
Filename string
}
// sentTarget records the channel+chatID that the message tool sent to.
type sentTarget struct {
@@ -16,10 +38,13 @@ type sentTarget struct {
type MessageTool struct {
sendCallback SendCallbackWithContext
workspace string
restrict bool
maxFileSize int
mediaStore media.MediaStore
allowPaths []*regexp.Regexp
mu sync.Mutex
// sentTargets tracks targets sent to in the current round, keyed by session key
// to support parallel turns for different sessions.
sentTargets map[string][]sentTarget
sentTargets map[string][]sentTarget
}
func NewMessageTool() *MessageTool {
@@ -33,7 +58,7 @@ func (t *MessageTool) Name() string {
}
func (t *MessageTool) Description() string {
return "Send a message to user on a chat channel. Use this when you want to communicate something."
return "Send a message to the user on a chat channel. Supports text-only, media-only, or text with media attachments."
}
func (t *MessageTool) Parameters() map[string]any {
@@ -42,7 +67,29 @@ func (t *MessageTool) Parameters() map[string]any {
"properties": map[string]any{
"content": map[string]any{
"type": "string",
"description": "The message content to send",
"description": "Optional message text. When media is present, this text is used as the caption/body for the media message.",
},
"media": map[string]any{
"type": "array",
"description": "Optional local media attachments to send with the message.",
"items": map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the local file. Relative paths are resolved from workspace.",
},
"type": map[string]any{
"type": "string",
"description": "Optional media type hint: image, audio, video, or file.",
},
"filename": map[string]any{
"type": "string",
"description": "Optional display filename. Defaults to the basename of path.",
},
},
"required": []string{"path"},
},
},
"channel": map[string]any{
"type": "string",
@@ -57,10 +104,32 @@ func (t *MessageTool) Parameters() map[string]any {
"description": "Optional: reply target message ID for channels that support threaded replies",
},
},
"required": []string{"content"},
"anyOf": []map[string]any{
{"required": []string{"content"}},
{"required": []string{"media"}},
},
}
}
func (t *MessageTool) ConfigureLocalMedia(
workspace string,
restrict bool,
maxFileSize int,
allowPaths []*regexp.Regexp,
) {
t.workspace = workspace
t.restrict = restrict
if maxFileSize <= 0 {
maxFileSize = config.DefaultMaxMediaSize
}
t.maxFileSize = maxFileSize
t.allowPaths = allowPaths
}
func (t *MessageTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
// ResetSentInRound resets the per-round send tracker for the given session key.
// Called by the agent loop at the start of each inbound message processing round.
func (t *MessageTool) ResetSentInRound(sessionKey string) {
@@ -98,9 +167,14 @@ func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
}
func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
content, ok := args["content"].(string)
if !ok {
return &ToolResult{ForLLM: "content is required", IsError: true}
content, _ := args["content"].(string)
content = strings.TrimSpace(content)
mediaArgs, err := parseMessageMediaArgs(args["media"])
if err != nil {
return &ToolResult{ForLLM: err.Error(), IsError: true}
}
if content == "" && len(mediaArgs) == 0 {
return &ToolResult{ForLLM: "content or media is required", IsError: true}
}
channel, _ := args["channel"].(string)
@@ -122,7 +196,12 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
parts, err := t.buildMediaParts(channel, chatID, content, mediaArgs)
if err != nil {
return &ToolResult{ForLLM: err.Error(), IsError: true, Err: err}
}
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID, parts); err != nil {
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
@@ -135,9 +214,146 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID})
t.mu.Unlock()
// Silent: user already received the message directly
status := fmt.Sprintf("Message sent to %s:%s", channel, chatID)
if len(parts) > 0 {
status = fmt.Sprintf("Message with %d media attachment(s) sent to %s:%s", len(parts), channel, chatID)
}
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
ForLLM: status,
Silent: true,
}
}
func parseMessageMediaArgs(raw any) ([]messageMediaArg, error) {
if raw == nil {
return nil, nil
}
items, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("media must be an array")
}
result := make([]messageMediaArg, 0, len(items))
for i, item := range items {
obj, ok := item.(map[string]any)
if !ok {
return nil, fmt.Errorf("media[%d] must be an object", i)
}
path, _ := obj["path"].(string)
path = strings.TrimSpace(path)
if path == "" {
return nil, fmt.Errorf("media[%d].path is required", i)
}
typ, _ := obj["type"].(string)
filename, _ := obj["filename"].(string)
result = append(result, messageMediaArg{
Path: path,
Type: strings.TrimSpace(typ),
Filename: strings.TrimSpace(filename),
})
}
return result, nil
}
func (t *MessageTool) buildMediaParts(
channel, chatID, content string,
mediaArgs []messageMediaArg,
) ([]bus.MediaPart, error) {
if len(mediaArgs) == 0 {
return nil, nil
}
if t.mediaStore == nil {
return nil, fmt.Errorf("media store not configured")
}
if strings.TrimSpace(t.workspace) == "" {
return nil, fmt.Errorf("message media delivery is not configured")
}
scope := fmt.Sprintf("tool:message:%s:%s", channel, chatID)
parts := make([]bus.MediaPart, 0, len(mediaArgs))
for i, item := range mediaArgs {
resolved, err := fstools.ValidatePathWithAllowPaths(item.Path, t.workspace, t.restrict, t.allowPaths)
if err != nil {
return nil, fmt.Errorf("invalid media[%d].path: %w", i, err)
}
info, err := os.Stat(resolved)
if err != nil {
return nil, fmt.Errorf("media[%d] file not found: %w", i, err)
}
if info.IsDir() {
return nil, fmt.Errorf("media[%d] path is a directory, expected a file", i)
}
if t.maxFileSize > 0 && info.Size() > int64(t.maxFileSize) {
return nil, fmt.Errorf("media[%d] file too large: %d bytes (max %d bytes)", i, info.Size(), t.maxFileSize)
}
filename := item.Filename
if filename == "" {
filename = filepath.Base(resolved)
}
contentType := detectMessageMediaType(resolved)
partType := normalizeMessageMediaType(item.Type, filename, contentType)
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
Filename: filename,
ContentType: contentType,
Source: "tool:message",
CleanupPolicy: media.CleanupPolicyForgetOnly,
}, scope)
if err != nil {
return nil, fmt.Errorf("failed to register media[%d]: %w", i, err)
}
part := bus.MediaPart{
Type: partType,
Ref: ref,
Filename: filename,
ContentType: contentType,
}
if i == 0 && content != "" {
part.Caption = content
}
parts = append(parts, part)
}
return parts, nil
}
func detectMessageMediaType(path string) string {
kind, err := filetype.MatchFile(path)
if err == nil && kind != filetype.Unknown {
return kind.MIME.Value
}
if ext := filepath.Ext(path); ext != "" {
if t := mime.TypeByExtension(ext); t != "" {
return t
}
}
return "application/octet-stream"
}
func normalizeMessageMediaType(typeHint, filename, contentType string) string {
switch strings.ToLower(strings.TrimSpace(typeHint)) {
case "image", "audio", "video", "file":
return strings.ToLower(strings.TrimSpace(typeHint))
}
ct := strings.ToLower(strings.TrimSpace(contentType))
switch {
case strings.HasPrefix(ct, "image/"):
return "image"
case strings.HasPrefix(ct, "audio/"):
return "audio"
case strings.HasPrefix(ct, "video/"):
return "video"
}
switch strings.ToLower(filepath.Ext(filename)) {
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp":
return "image"
case ".mp3", ".wav", ".ogg", ".oga", ".m4a", ".flac":
return "audio"
case ".mp4", ".mov", ".mkv", ".webm", ".avi":
return "video"
default:
return "file"
}
}
+105 -13
View File
@@ -3,8 +3,13 @@ package integrationtools
import (
"context"
"errors"
"os"
"path/filepath"
"regexp"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/session"
)
@@ -12,10 +17,17 @@ func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
if len(mediaParts) != 0 {
t.Fatalf("expected no media parts, got %d", len(mediaParts))
}
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
@@ -67,7 +79,11 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID string
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
sentChannel = channel
sentChatID = chatID
return nil
@@ -102,7 +118,11 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
sendErr := errors.New("network error")
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
return sendErr
})
@@ -142,12 +162,12 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
result := tool.Execute(ctx, args)
// Verify error result for missing content
// Verify error result for missing content/media
if !result.IsError {
t.Error("Expected IsError=true for missing content")
t.Error("Expected IsError=true for missing content/media")
}
if result.ForLLM != "content is required" {
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
if result.ForLLM != "content or media is required" {
t.Errorf("Expected ForLLM 'content or media is required', got '%s'", result.ForLLM)
}
}
@@ -155,7 +175,11 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
return nil
})
@@ -226,9 +250,9 @@ func TestMessageTool_Parameters(t *testing.T) {
}
// Check required properties
required, ok := params["required"].([]string)
if !ok || len(required) != 1 || required[0] != "content" {
t.Error("Expected 'content' to be required")
anyOf, ok := params["anyOf"].([]map[string]any)
if !ok || len(anyOf) != 2 {
t.Fatal("Expected anyOf content/media requirement")
}
// Check content property
@@ -240,6 +264,14 @@ func TestMessageTool_Parameters(t *testing.T) {
t.Error("Expected content type to be 'string'")
}
mediaProp, ok := props["media"].(map[string]any)
if !ok {
t.Fatal("Expected 'media' property")
}
if mediaProp["type"] != "array" {
t.Error("Expected media type to be 'array'")
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]any)
if !ok {
@@ -272,7 +304,11 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
tool := NewMessageTool()
var sentReplyTo string
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
sentReplyTo = replyToMessageID
return nil
})
@@ -297,7 +333,11 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
var gotAgentID, gotSessionKey string
var gotScope *session.SessionScope
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
gotAgentID = ToolAgentID(ctx)
gotSessionKey = ToolSessionKey(ctx)
gotScope = ToolSessionScope(ctx)
@@ -329,3 +369,55 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
}
}
func TestMessageTool_Execute_WithMedia(t *testing.T) {
tool := NewMessageTool()
store := media.NewFileMediaStore()
dir := t.TempDir()
imgPath := filepath.Join(dir, "photo.jpg")
if err := os.WriteFile(imgPath, []byte("fake image bytes"), 0o644); err != nil {
t.Fatalf("write image: %v", err)
}
tool.ConfigureLocalMedia(dir, true, 1024*1024, []*regexp.Regexp{})
tool.SetMediaStore(store)
var gotContent string
var gotParts []bus.MediaPart
tool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
mediaParts []bus.MediaPart,
) error {
gotContent = content
gotParts = append([]bus.MediaPart(nil), mediaParts...)
return nil
})
ctx := WithToolContext(context.Background(), "telegram", "-1001")
result := tool.Execute(ctx, map[string]any{
"content": "Caption text",
"media": []any{
map[string]any{
"path": imgPath,
},
},
})
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if gotContent != "Caption text" {
t.Fatalf("content = %q, want Caption text", gotContent)
}
if len(gotParts) != 1 {
t.Fatalf("expected 1 media part, got %d", len(gotParts))
}
if gotParts[0].Caption != "Caption text" {
t.Fatalf("first part caption = %q, want Caption text", gotParts[0].Caption)
}
if gotParts[0].Ref == "" {
t.Fatal("expected media ref to be populated")
}
if gotParts[0].Type == "" {
t.Fatal("expected media type to be inferred")
}
}