mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2856 from bogdanovich/feat/message-media-outbound
feat(message): support media attachments and Telegram rich delivery
This commit is contained in:
@@ -400,6 +400,7 @@ Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous
|
||||
|------------|------|---------|-------------|
|
||||
| `tools.allow_read_paths` | string[] | `[]` | Additional paths allowed for reading outside workspace |
|
||||
| `tools.allow_write_paths` | string[] | `[]` | Additional paths allowed for writing outside workspace |
|
||||
| `tools.message.media_enabled` | bool | `false` | Allows the `message` tool to attach local media files by path. This is separate from `tools.send_file.enabled`; enable it only when unified text/media/caption delivery is intended. |
|
||||
|
||||
### Read File Mode
|
||||
|
||||
|
||||
+36
-4
@@ -161,26 +161,58 @@ func registerSharedTools(
|
||||
// Message tool
|
||||
if cfg.Tools.IsToolEnabled("message") {
|
||||
messageTool := tools.NewMessageTool()
|
||||
if cfg.Tools.Message.MediaEnabled {
|
||||
messageTool.ConfigureLocalMedia(
|
||||
agent.Workspace,
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
cfg.Agents.Defaults.GetMaxMediaSize(),
|
||||
allowReadPaths,
|
||||
)
|
||||
}
|
||||
messageTool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID)
|
||||
outboundAgentID, outboundSessionKey, outboundScope := outboundTurnMetadata(
|
||||
tools.ToolAgentID(ctx),
|
||||
tools.ToolSessionKey(ctx),
|
||||
tools.ToolSessionScope(ctx),
|
||||
)
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
if len(mediaParts) > 0 {
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: outboundCtx,
|
||||
AgentID: outboundAgentID,
|
||||
SessionKey: outboundSessionKey,
|
||||
Scope: outboundScope,
|
||||
Parts: mediaParts,
|
||||
}
|
||||
if al.channelManager != nil && channel != "" {
|
||||
return al.channelManager.SendMedia(ctx, outboundMedia)
|
||||
}
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
return msgBus.PublishOutboundMedia(pubCtx, outboundMedia)
|
||||
}
|
||||
outboundMessage := bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: outboundCtx,
|
||||
AgentID: outboundAgentID,
|
||||
SessionKey: outboundSessionKey,
|
||||
Scope: outboundScope,
|
||||
Content: content,
|
||||
ReplyToMessageID: replyToMessageID,
|
||||
})
|
||||
}
|
||||
if al.channelManager != nil && channel != "" {
|
||||
return al.channelManager.SendMessage(ctx, outboundMessage)
|
||||
}
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
return msgBus.PublishOutbound(pubCtx, outboundMessage)
|
||||
})
|
||||
agent.Tools.Register(messageTool)
|
||||
}
|
||||
|
||||
@@ -377,7 +377,11 @@ func TestPublishResponseIfNeeded_DismissesToolFeedbackWhenMessageToolAlreadySent
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
mt := tools.NewMessageTool()
|
||||
mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
mt.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
return nil
|
||||
})
|
||||
defaultAgent.Tools.Register(mt)
|
||||
|
||||
@@ -52,6 +52,8 @@ type FeishuChannel struct {
|
||||
|
||||
progress *channels.ToolFeedbackAnimator
|
||||
deleteMessageFn func(context.Context, string, string) error
|
||||
sendMediaPartFn func(context.Context, string, bus.MediaPart, media.MediaStore) error
|
||||
sendTextFn func(context.Context, string, string) (string, error)
|
||||
}
|
||||
|
||||
type cachedMessage struct {
|
||||
@@ -78,6 +80,8 @@ func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.M
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret.String(), opts...),
|
||||
}
|
||||
ch.deleteMessageFn = ch.deleteMessageAPI
|
||||
ch.sendMediaPartFn = ch.sendMediaPart
|
||||
ch.sendTextFn = ch.sendText
|
||||
ch.progress = channels.NewToolFeedbackAnimator(ch.EditMessage)
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
@@ -497,8 +501,16 @@ func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMess
|
||||
return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
caption := firstMediaCaption(msg.Parts)
|
||||
sentAny := false
|
||||
for _, part := range msg.Parts {
|
||||
if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil {
|
||||
if err := c.sendMediaPartFn(ctx, msg.ChatID, part, store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sentAny = true
|
||||
}
|
||||
if sentAny && caption != "" {
|
||||
if _, err := c.sendTextFn(ctx, msg.ChatID, caption); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -557,6 +569,15 @@ func (c *FeishuChannel) sendMediaPart(
|
||||
return nil
|
||||
}
|
||||
|
||||
func firstMediaCaption(parts []bus.MediaPart) string {
|
||||
for _, part := range parts {
|
||||
if caption := strings.TrimSpace(part.Caption); caption != "" {
|
||||
return caption
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --- Inbound message handling ---
|
||||
|
||||
func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestExtractContent(t *testing.T) {
|
||||
@@ -319,6 +321,43 @@ func TestFinalizeTrackedToolFeedbackMessage_ClearAfterSuccessfulEdit(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_SendsCaptionFallbackAfterMedia(t *testing.T) {
|
||||
ch := &FeishuChannel{
|
||||
BaseChannel: channels.NewBaseChannel("feishu", nil, nil, nil),
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.SetMediaStore(media.NewFileMediaStore())
|
||||
|
||||
var mediaOrder []string
|
||||
var textCalls []string
|
||||
ch.sendMediaPartFn = func(ctx context.Context, chatID string, part bus.MediaPart, store media.MediaStore) error {
|
||||
mediaOrder = append(mediaOrder, part.Type)
|
||||
return nil
|
||||
}
|
||||
ch.sendTextFn = func(ctx context.Context, chatID, text string) (string, error) {
|
||||
textCalls = append(textCalls, chatID+"|"+text)
|
||||
return "msg-1", nil
|
||||
}
|
||||
|
||||
_, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "oc_123",
|
||||
Parts: []bus.MediaPart{
|
||||
{Type: "image", Caption: "shared caption"},
|
||||
{Type: "file"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
if len(mediaOrder) != 2 {
|
||||
t.Fatalf("media sends = %v, want 2 sends", mediaOrder)
|
||||
}
|
||||
if len(textCalls) != 1 || textCalls[0] != "oc_123|shared caption" {
|
||||
t.Fatalf("textCalls = %v, want [oc_123|shared caption]", textCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeTrackedToolFeedbackMessage_StopsTrackingBeforeEdit(t *testing.T) {
|
||||
ch := &FeishuChannel{
|
||||
progress: channels.NewToolFeedbackAnimator(nil),
|
||||
|
||||
@@ -835,6 +835,75 @@ func TestSendMedia_DismissesTrackedToolFeedbackMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_IncludesCaptionAndAttachmentsInSinglePayload(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
if err := ch.Start(context.Background()); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
defer ch.Stop(context.Background())
|
||||
|
||||
clientConn, received, cleanup := newTestPicoWebSocket(t)
|
||||
defer cleanup()
|
||||
ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"})
|
||||
|
||||
localPath := filepath.Join(t.TempDir(), "photo.png")
|
||||
if err := os.WriteFile(localPath, []byte("png-body"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
ref, err := store.Store(localPath, media.MediaMeta{
|
||||
Filename: "photo.png",
|
||||
ContentType: "image/png",
|
||||
}, "test-scope")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "pico:sess-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Ref: ref,
|
||||
Type: "image",
|
||||
Filename: "photo.png",
|
||||
ContentType: "image/png",
|
||||
Caption: "recipe translation",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-received:
|
||||
if msg.Type != TypeMessageCreate {
|
||||
t.Fatalf("message type = %q, want %q", msg.Type, TypeMessageCreate)
|
||||
}
|
||||
payload := msg.Payload
|
||||
if got := payload[PayloadKeyContent]; got != "recipe translation" {
|
||||
t.Fatalf("content = %#v, want %q", got, "recipe translation")
|
||||
}
|
||||
rawAttachments, ok := payload["attachments"].([]any)
|
||||
if !ok || len(rawAttachments) != 1 {
|
||||
t.Fatalf("attachments = %#v, want 1 attachment", payload["attachments"])
|
||||
}
|
||||
attachment, ok := rawAttachments[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("attachment = %#v, want map", rawAttachments[0])
|
||||
}
|
||||
if got := attachment["type"]; got != "image" {
|
||||
t.Fatalf("attachment type = %#v, want image", got)
|
||||
}
|
||||
if got := attachment["filename"]; got != "photo.png" {
|
||||
t.Fatalf("attachment filename = %#v, want photo.png", got)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected media payload to be delivered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPicoDownloadURLForRef(t *testing.T) {
|
||||
got, err := picoDownloadURLForRef("media://attachment-1")
|
||||
if err != nil {
|
||||
|
||||
@@ -29,6 +29,8 @@ type SlackChannel struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
pendingAcks sync.Map
|
||||
uploadFileFn func(context.Context, slack.UploadFileParameters) error
|
||||
postTextFn func(context.Context, string, string, string) error
|
||||
}
|
||||
|
||||
type slackMessageRef struct {
|
||||
@@ -63,6 +65,18 @@ func NewSlackChannel(
|
||||
config: cfg,
|
||||
api: api,
|
||||
socketClient: socketClient,
|
||||
uploadFileFn: func(ctx context.Context, params slack.UploadFileParameters) error {
|
||||
_, err := api.UploadFileContext(ctx, params)
|
||||
return err
|
||||
},
|
||||
postTextFn: func(ctx context.Context, channelID, threadTS, text string) error {
|
||||
opts := []slack.MsgOption{slack.MsgOptionText(text, false)}
|
||||
if threadTS != "" {
|
||||
opts = append(opts, slack.MsgOptionTS(threadTS))
|
||||
}
|
||||
_, _, err := api.PostMessageContext(ctx, channelID, opts...)
|
||||
return err
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -171,6 +185,8 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
caption := slackFirstMediaCaption(msg.Parts)
|
||||
sentAny := false
|
||||
for _, part := range msg.Parts {
|
||||
localPath, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
@@ -191,7 +207,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
title = filename
|
||||
}
|
||||
|
||||
_, err = c.api.UploadFileContext(ctx, slack.UploadFileParameters{
|
||||
err = c.uploadFileFn(ctx, slack.UploadFileParameters{
|
||||
Channel: channelID,
|
||||
ThreadTimestamp: threadTS,
|
||||
File: localPath,
|
||||
@@ -205,6 +221,13 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
})
|
||||
return nil, fmt.Errorf("slack send media: %w", channels.ErrTemporary)
|
||||
}
|
||||
sentAny = true
|
||||
}
|
||||
|
||||
if sentAny && caption != "" {
|
||||
if err := c.postTextFn(ctx, channelID, threadTS, caption); err != nil {
|
||||
return nil, fmt.Errorf("slack send media caption fallback: %w", channels.ErrTemporary)
|
||||
}
|
||||
}
|
||||
|
||||
// UploadFile does not expose the posted message timestamp in its
|
||||
@@ -212,6 +235,15 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func slackFirstMediaCaption(parts []bus.MediaPart) string {
|
||||
for _, part := range parts {
|
||||
if caption := strings.TrimSpace(part.Caption); caption != "" {
|
||||
return caption
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ReactToMessage implements channels.ReactionCapable.
|
||||
// It adds an "eyes" (👀) reaction to the inbound message and returns an undo function
|
||||
// that removes the reaction.
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
package slack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
slacksdk "github.com/slack-go/slack"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestParseSlackChatID(t *testing.T) {
|
||||
@@ -184,3 +191,74 @@ func TestSlackChannelIsAllowed(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendMedia_SendsCaptionFallbackAfterUploads(t *testing.T) {
|
||||
ch := &SlackChannel{
|
||||
BaseChannel: channels.NewBaseChannel("slack", nil, nil, nil),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
localPath := filepath.Join(tmpDir, "report.txt")
|
||||
if err := os.WriteFile(localPath, []byte("attachment body"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
ref, err := store.Store(localPath, media.MediaMeta{
|
||||
Filename: "report.txt",
|
||||
ContentType: "text/plain",
|
||||
}, "test-scope")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
var uploaded []slackUploadRecord
|
||||
var posted []string
|
||||
ch.uploadFileFn = func(ctx context.Context, params slacksdk.UploadFileParameters) error {
|
||||
uploaded = append(uploaded, slackUploadRecord{
|
||||
Channel: params.Channel,
|
||||
Thread: params.ThreadTimestamp,
|
||||
File: params.File,
|
||||
Name: params.Filename,
|
||||
Title: params.Title,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
ch.postTextFn = func(ctx context.Context, channelID, threadTS, text string) error {
|
||||
posted = append(posted, channelID+"|"+threadTS+"|"+text)
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "C123456/1234567890.123456",
|
||||
Parts: []bus.MediaPart{{
|
||||
Ref: ref,
|
||||
Type: "file",
|
||||
Filename: "report.txt",
|
||||
ContentType: "text/plain",
|
||||
Caption: "shared caption",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
if len(uploaded) != 1 {
|
||||
t.Fatalf("uploads = %v, want 1 upload", uploaded)
|
||||
}
|
||||
if uploaded[0].Title != "shared caption" {
|
||||
t.Fatalf("upload title = %q, want shared caption", uploaded[0].Title)
|
||||
}
|
||||
if len(posted) != 1 || posted[0] != "C123456|1234567890.123456|shared caption" {
|
||||
t.Fatalf("posted = %v, want fallback text in same thread", posted)
|
||||
}
|
||||
}
|
||||
|
||||
// slackUploadRecord is a test-only snapshot of the upload parameters that the
// stubbed upload hook received: target channel, thread timestamp, local file
// path, file name and title.
type slackUploadRecord struct {
	Channel, Thread, File, Name, Title string
}
|
||||
|
||||
@@ -44,7 +44,10 @@ var (
|
||||
reInlineCode = regexp.MustCompile("`([^`]+)`")
|
||||
)
|
||||
|
||||
const defaultMediaGroupDelay = 500 * time.Millisecond
|
||||
// Timing and size limits used by the Telegram channel.
const (
|
||||
// defaultMediaGroupDelay is a half-second duration.
// NOTE(review): presumably the wait applied while buffering inbound
// media-group messages — confirm against its users.
defaultMediaGroupDelay = 500 * time.Millisecond
|
||||
// telegramCaptionLimit is the rune-count threshold for a leading media
// caption: SendMedia sends a longer caption as separate text before the
// media and clears the captions on the parts.
telegramCaptionLimit = 1024
|
||||
)
|
||||
|
||||
type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
@@ -639,6 +642,34 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
}
|
||||
|
||||
var messageIDs []string
|
||||
leadingCaption := telegramLeadingCaption(msg.Parts)
|
||||
if len([]rune(leadingCaption)) > telegramCaptionLimit {
|
||||
leadingIDs, leadingErr := c.sendCaptionText(ctx, chatID, threadID, leadingCaption)
|
||||
if leadingErr != nil {
|
||||
return nil, leadingErr
|
||||
}
|
||||
messageIDs = append(messageIDs, leadingIDs...)
|
||||
msg = telegramClearMediaCaptions(msg)
|
||||
}
|
||||
|
||||
if len(msg.Parts) > 1 && telegramCanSendMediaGroup(msg.Parts) {
|
||||
groupIDs, err := c.sendImageMediaGroups(ctx, chatID, threadID, store, msg.Parts)
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Failed to send media group", map[string]any{
|
||||
"count": len(msg.Parts),
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil, fmt.Errorf("telegram send media group: %w", channels.ErrTemporary)
|
||||
}
|
||||
if len(groupIDs) > 0 {
|
||||
messageIDs = append(messageIDs, groupIDs...)
|
||||
if hasTrackedMsg {
|
||||
c.dismissTrackedToolFeedbackMessage(ctx, trackedChatID, trackedMsgID)
|
||||
}
|
||||
return messageIDs, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
localPath, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
@@ -742,6 +773,154 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
return messageIDs, nil
|
||||
}
|
||||
|
||||
func telegramCanSendMediaGroup(parts []bus.MediaPart) bool {
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type != "image" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) sendImageMediaGroups(
|
||||
ctx context.Context,
|
||||
chatID int64,
|
||||
threadID int,
|
||||
store media.MediaStore,
|
||||
parts []bus.MediaPart,
|
||||
) ([]string, error) {
|
||||
const maxGroupSize = 10
|
||||
|
||||
messageIDs := make([]string, 0, len(parts))
|
||||
for start := 0; start < len(parts); start += maxGroupSize {
|
||||
end := start + maxGroupSize
|
||||
if end > len(parts) {
|
||||
end = len(parts)
|
||||
}
|
||||
groupIDs, err := c.sendSingleImageMediaGroup(ctx, chatID, threadID, store, parts[start:end])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messageIDs = append(messageIDs, groupIDs...)
|
||||
}
|
||||
return messageIDs, nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) sendSingleImageMediaGroup(
|
||||
ctx context.Context,
|
||||
chatID int64,
|
||||
threadID int,
|
||||
store media.MediaStore,
|
||||
parts []bus.MediaPart,
|
||||
) ([]string, error) {
|
||||
opened := make([]*os.File, 0, len(parts))
|
||||
defer func() {
|
||||
for _, file := range opened {
|
||||
file.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
inputMedia := make([]telego.InputMedia, 0, len(parts))
|
||||
for i, part := range parts {
|
||||
localPath, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Failed to resolve media ref for media group", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Failed to open media file for media group", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
opened = append(opened, file)
|
||||
|
||||
mediaItem := &telego.InputMediaPhoto{
|
||||
Type: telego.MediaTypePhoto,
|
||||
Media: telego.InputFile{File: file},
|
||||
}
|
||||
if i == 0 {
|
||||
mediaItem.Caption = part.Caption
|
||||
}
|
||||
inputMedia = append(inputMedia, mediaItem)
|
||||
}
|
||||
|
||||
results, err := c.bot.SendMediaGroup(ctx, &telego.SendMediaGroupParams{
|
||||
ChatID: tu.ID(chatID),
|
||||
MessageThreadID: threadID,
|
||||
Media: inputMedia,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messageIDs := make([]string, 0, len(results))
|
||||
for _, result := range results {
|
||||
messageIDs = append(messageIDs, strconv.Itoa(result.MessageID))
|
||||
}
|
||||
return messageIDs, nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) sendCaptionText(
|
||||
ctx context.Context,
|
||||
chatID int64,
|
||||
threadID int,
|
||||
text string,
|
||||
) ([]string, error) {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
chunks := channels.SplitMessage(text, c.MaxMessageLength())
|
||||
messageIDs := make([]string, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
chunk = strings.TrimSpace(chunk)
|
||||
if chunk == "" {
|
||||
continue
|
||||
}
|
||||
msgID, err := c.sendChunk(ctx, sendChunkParams{
|
||||
chatID: chatID,
|
||||
threadID: threadID,
|
||||
content: chunk,
|
||||
mdFallback: chunk,
|
||||
useMarkdownV2: false,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messageIDs = append(messageIDs, msgID)
|
||||
}
|
||||
return messageIDs, nil
|
||||
}
|
||||
|
||||
func telegramLeadingCaption(parts []bus.MediaPart) string {
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[0].Caption)
|
||||
}
|
||||
|
||||
func telegramClearMediaCaptions(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage {
|
||||
if len(msg.Parts) == 0 {
|
||||
return msg
|
||||
}
|
||||
cloned := msg
|
||||
cloned.Parts = append([]bus.MediaPart(nil), msg.Parts...)
|
||||
for i := range cloned.Parts {
|
||||
cloned.Parts[i].Caption = ""
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
|
||||
if message != nil && strings.TrimSpace(message.MediaGroupID) != "" {
|
||||
return c.bufferMediaGroupMessage(ctx, message)
|
||||
|
||||
@@ -110,6 +110,17 @@ func successResponseWithMessageID(t *testing.T, messageID int) *ta.Response {
|
||||
return &ta.Response{Ok: true, Result: b}
|
||||
}
|
||||
|
||||
func successMediaGroupResponse(t *testing.T, messageIDs ...int) *ta.Response {
|
||||
t.Helper()
|
||||
messages := make([]telego.Message, 0, len(messageIDs))
|
||||
for _, messageID := range messageIDs {
|
||||
messages = append(messages, telego.Message{MessageID: messageID})
|
||||
}
|
||||
b, err := json.Marshal(messages)
|
||||
require.NoError(t, err)
|
||||
return &ta.Response{Ok: true, Result: b}
|
||||
}
|
||||
|
||||
func successUserResponse(t *testing.T, user *telego.User) *ta.Response {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(user)
|
||||
@@ -237,6 +248,276 @@ func TestSendMedia_ImageNonDimensionErrorDoesNotFallback(t *testing.T) {
|
||||
assert.NotContains(t, caller.calls[0].URL, "sendDocument")
|
||||
}
|
||||
|
||||
func TestSendMedia_MultipleImagesUseMediaGroup(t *testing.T) {
|
||||
constructor := &multipartRecordingConstructor{}
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
if strings.Contains(url, "sendMediaGroup") {
|
||||
return successMediaGroupResponse(t, 101, 102), nil
|
||||
}
|
||||
t.Fatalf("unexpected API call: %s", url)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannelWithConstructor(t, caller, constructor)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
firstPath := filepath.Join(tmpDir, "first.png")
|
||||
secondPath := filepath.Join(tmpDir, "second.png")
|
||||
require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644))
|
||||
require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644))
|
||||
|
||||
firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1")
|
||||
require.NoError(t, err)
|
||||
secondRef, err := store.Store(
|
||||
secondPath,
|
||||
media.MediaMeta{Filename: "second.png", ContentType: "image/png"},
|
||||
"scope-1",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "12345",
|
||||
Parts: []bus.MediaPart{
|
||||
{Type: "image", Ref: firstRef, Caption: "album caption"},
|
||||
{Type: "image", Ref: secondRef},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"101", "102"}, ids)
|
||||
require.Len(t, caller.calls, 1)
|
||||
assert.Contains(t, caller.calls[0].URL, "sendMediaGroup")
|
||||
require.Len(t, constructor.calls, 1)
|
||||
require.Len(t, constructor.calls[0].FileSizes, 2)
|
||||
|
||||
var mediaPayload []map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(constructor.calls[0].Parameters["media"]), &mediaPayload))
|
||||
require.Len(t, mediaPayload, 2)
|
||||
assert.Equal(t, "album caption", mediaPayload[0]["caption"])
|
||||
_, hasSecondCaption := mediaPayload[1]["caption"]
|
||||
assert.False(t, hasSecondCaption)
|
||||
}
|
||||
|
||||
func TestSendMedia_MoreThanTenImagesSplitIntoMediaGroups(t *testing.T) {
|
||||
constructor := &multipartRecordingConstructor{}
|
||||
callIndex := 0
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
if !strings.Contains(url, "sendMediaGroup") {
|
||||
t.Fatalf("unexpected API call: %s", url)
|
||||
}
|
||||
callIndex++
|
||||
if callIndex == 1 {
|
||||
return successMediaGroupResponse(t, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010), nil
|
||||
}
|
||||
if callIndex == 2 {
|
||||
return successMediaGroupResponse(t, 1011, 1012, 1013, 1014, 1015), nil
|
||||
}
|
||||
t.Fatalf("unexpected sendMediaGroup call #%d", callIndex)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannelWithConstructor(t, caller, constructor)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
parts := make([]bus.MediaPart, 0, 15)
|
||||
for i := 0; i < 15; i++ {
|
||||
path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png")
|
||||
require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644))
|
||||
ref, err := store.Store(
|
||||
path,
|
||||
media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"},
|
||||
"scope-1",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
part := bus.MediaPart{Type: "image", Ref: ref}
|
||||
if i == 0 {
|
||||
part.Caption = "long album caption"
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
|
||||
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "12345",
|
||||
Parts: parts,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"1001", "1002", "1003", "1004", "1005",
|
||||
"1006", "1007", "1008", "1009", "1010",
|
||||
"1011", "1012", "1013", "1014", "1015",
|
||||
}, ids)
|
||||
require.Len(t, caller.calls, 2)
|
||||
require.Len(t, constructor.calls, 2)
|
||||
}
|
||||
|
||||
func TestSendMedia_SingleImageLongCaptionSendsTextFirst(t *testing.T) {
|
||||
constructor := &multipartRecordingConstructor{}
|
||||
longCaption := strings.Repeat("a", telegramCaptionLimit) + " tail overflow"
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
switch {
|
||||
case strings.Contains(url, "sendMessage"):
|
||||
return successResponseWithMessageID(t, 201), nil
|
||||
case strings.Contains(url, "sendPhoto"):
|
||||
return successResponseWithMessageID(t, 202), nil
|
||||
default:
|
||||
t.Fatalf("unexpected API call: %s", url)
|
||||
return nil, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
ch := newTestChannelWithConstructor(t, caller, constructor)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "image.png")
|
||||
require.NoError(t, os.WriteFile(path, []byte("img"), 0o644))
|
||||
ref, err := store.Store(path, media.MediaMeta{Filename: "image.png", ContentType: "image/png"}, "scope-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "12345",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "image",
|
||||
Ref: ref,
|
||||
Caption: longCaption,
|
||||
}},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"201", "202"}, ids)
|
||||
require.Len(t, caller.calls, 2)
|
||||
assert.Contains(t, caller.calls[0].URL, "sendMessage")
|
||||
assert.Contains(t, caller.calls[1].URL, "sendPhoto")
|
||||
assert.Equal(t, "", constructor.calls[0].Parameters["caption"])
|
||||
}
|
||||
|
||||
func TestSendMedia_MediaGroupLongCaptionSendsTextFirst(t *testing.T) {
|
||||
constructor := &multipartRecordingConstructor{}
|
||||
longCaption := strings.Repeat("b", telegramCaptionLimit) + " trailing explanation"
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
switch {
|
||||
case strings.Contains(url, "sendMessage"):
|
||||
return successResponseWithMessageID(t, 301), nil
|
||||
case strings.Contains(url, "sendMediaGroup"):
|
||||
return successMediaGroupResponse(t, 302, 303), nil
|
||||
default:
|
||||
t.Fatalf("unexpected API call: %s", url)
|
||||
return nil, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
ch := newTestChannelWithConstructor(t, caller, constructor)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
firstPath := filepath.Join(tmpDir, "first.png")
|
||||
secondPath := filepath.Join(tmpDir, "second.png")
|
||||
require.NoError(t, os.WriteFile(firstPath, []byte("first-image"), 0o644))
|
||||
require.NoError(t, os.WriteFile(secondPath, []byte("second-image"), 0o644))
|
||||
|
||||
firstRef, err := store.Store(firstPath, media.MediaMeta{Filename: "first.png", ContentType: "image/png"}, "scope-1")
|
||||
require.NoError(t, err)
|
||||
secondRef, err := store.Store(
|
||||
secondPath,
|
||||
media.MediaMeta{Filename: "second.png", ContentType: "image/png"},
|
||||
"scope-1",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "12345",
|
||||
Parts: []bus.MediaPart{
|
||||
{Type: "image", Ref: firstRef, Caption: longCaption},
|
||||
{Type: "image", Ref: secondRef},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"301", "302", "303"}, ids)
|
||||
require.Len(t, caller.calls, 2)
|
||||
assert.Contains(t, caller.calls[0].URL, "sendMessage")
|
||||
assert.Contains(t, caller.calls[1].URL, "sendMediaGroup")
|
||||
}
|
||||
|
||||
func TestSendMedia_MultiGroupLongCaptionSendsTextBeforeGroups(t *testing.T) {
|
||||
constructor := &multipartRecordingConstructor{}
|
||||
longCaption := strings.Repeat("c", telegramCaptionLimit) + " overflow before second album"
|
||||
callOrder := make([]string, 0, 3)
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
switch {
|
||||
case strings.Contains(url, "sendMessage"):
|
||||
callOrder = append(callOrder, "text")
|
||||
return successResponseWithMessageID(t, 499), nil
|
||||
case strings.Contains(url, "sendMediaGroup"):
|
||||
callOrder = append(callOrder, "group")
|
||||
if len(callOrder) == 2 {
|
||||
return successMediaGroupResponse(t, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410), nil
|
||||
}
|
||||
if len(callOrder) == 3 {
|
||||
return successMediaGroupResponse(t, 411, 412, 413, 414, 415), nil
|
||||
}
|
||||
t.Fatalf("unexpected sendMediaGroup order: %v", callOrder)
|
||||
return nil, nil
|
||||
default:
|
||||
t.Fatalf("unexpected API call: %s", url)
|
||||
return nil, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
ch := newTestChannelWithConstructor(t, caller, constructor)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
parts := make([]bus.MediaPart, 0, 15)
|
||||
for i := 0; i < 15; i++ {
|
||||
path := filepath.Join(tmpDir, "image-"+strconv.Itoa(i)+".png")
|
||||
require.NoError(t, os.WriteFile(path, []byte("img-"+strconv.Itoa(i)), 0o644))
|
||||
ref, err := store.Store(
|
||||
path,
|
||||
media.MediaMeta{Filename: filepath.Base(path), ContentType: "image/png"},
|
||||
"scope-1",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
part := bus.MediaPart{Type: "image", Ref: ref}
|
||||
if i == 0 {
|
||||
part.Caption = longCaption
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
|
||||
ids, err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "12345",
|
||||
Parts: parts,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{
|
||||
"499",
|
||||
"401", "402", "403", "404", "405",
|
||||
"406", "407", "408", "409", "410",
|
||||
"411", "412", "413", "414", "415",
|
||||
}, ids)
|
||||
assert.Equal(t, []string{"text", "group", "group"}, callOrder)
|
||||
}
|
||||
|
||||
func TestSend_EmptyContent(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -319,3 +320,61 @@ func TestSelectInboundMediaItemFallsBackToRefMessage(t *testing.T) {
|
||||
t.Fatalf("selectInboundMediaItem().Type = %d, want %d", item.Type, MessageItemTypeImage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendUploadedMedia_SendsCaptionAsSeparateTextBeforeMedia(t *testing.T) {
|
||||
var requests []SendMessageReq
|
||||
ch := &WeixinChannel{
|
||||
api: &ApiClient{
|
||||
BaseURL: "https://ilinkai.weixin.qq.com/",
|
||||
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path != "/ilink/bot/sendmessage" {
|
||||
t.Fatalf("sendmessage path = %q, want /ilink/bot/sendmessage", r.URL.Path)
|
||||
}
|
||||
var req SendMessageReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode sendmessage req: %v", err)
|
||||
}
|
||||
requests = append(requests, req)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader([]byte(`{"ret":0,"errcode":0}`))),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
err := ch.sendUploadedMedia(
|
||||
context.Background(),
|
||||
"user-1",
|
||||
"ctx-1",
|
||||
"recipe translation",
|
||||
UploadMediaTypeImage,
|
||||
&uploadedFileInfo{
|
||||
downloadParam: "download-token",
|
||||
aesKeyHex: "31323334353637383930616263646566",
|
||||
fileSize: 11,
|
||||
cipherSize: 16,
|
||||
filename: "photo.png",
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("sendUploadedMedia() error = %v", err)
|
||||
}
|
||||
if len(requests) != 2 {
|
||||
t.Fatalf("sendUploadedMedia() sent %d requests, want 2", len(requests))
|
||||
}
|
||||
if len(requests[0].Msg.ItemList) != 1 || requests[0].Msg.ItemList[0].Type != MessageItemTypeText {
|
||||
t.Fatalf("first request item = %+v, want text item", requests[0].Msg.ItemList)
|
||||
}
|
||||
if got := requests[0].Msg.ItemList[0].TextItem.Text; got != "recipe translation" {
|
||||
t.Fatalf("first request text = %q, want recipe translation", got)
|
||||
}
|
||||
if len(requests[1].Msg.ItemList) != 1 || requests[1].Msg.ItemList[0].Type != MessageItemTypeImage {
|
||||
t.Fatalf("second request item = %+v, want image item", requests[1].Msg.ItemList)
|
||||
}
|
||||
if requests[1].Msg.ItemList[0].ImageItem == nil || requests[1].Msg.ItemList[0].ImageItem.Media == nil {
|
||||
t.Fatalf("second request image media = %+v, want media ref", requests[1].Msg.ItemList[0].ImageItem)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -814,6 +814,12 @@ type ToolConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"-" env:"ENABLED"`
|
||||
}
|
||||
|
||||
// MessageToolsConfig is the configuration of the `message` tool: the common
// ToolConfig settings plus a switch for local media attachments.
type MessageToolsConfig struct {
|
||||
// ToolConfig is embedded; its env-backed fields are read under the
// PICOCLAW_TOOLS_MESSAGE_ prefix.
ToolConfig `yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
|
||||
// MediaEnabled maps to tools.message.media_enabled. The env tag carries the
// full variable name rather than relying on a prefix.
MediaEnabled bool `json:"media_enabled" yaml:"-" env:"PICOCLAW_TOOLS_MESSAGE_MEDIA_ENABLED"`
|
||||
}
|
||||
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"-" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEYS"`
|
||||
@@ -1026,7 +1032,7 @@ type ToolsConfig struct {
|
||||
InstallSkill ToolConfig `json:"install_skill" yaml:"-" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
|
||||
ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
LoadImage ToolConfig `json:"load_image" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LOAD_IMAGE_"`
|
||||
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
Message MessageToolsConfig `json:"message" yaml:"-"`
|
||||
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"`
|
||||
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
|
||||
@@ -1480,6 +1480,16 @@ func TestLoadConfig_LoadImageCanBeDisabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_MessageMediaDisabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Tools.Message.Enabled {
|
||||
t.Fatal("DefaultConfig().Tools.Message.Enabled should be true")
|
||||
}
|
||||
if cfg.Tools.Message.MediaEnabled {
|
||||
t.Fatal("DefaultConfig().Tools.Message.MediaEnabled should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsConfig_GetFilterMinLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -447,8 +447,11 @@ func DefaultConfig() *Config {
|
||||
LoadImage: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Message: ToolConfig{
|
||||
Enabled: true,
|
||||
Message: MessageToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
MediaEnabled: false,
|
||||
},
|
||||
ReadFile: ReadFileToolConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -3,10 +3,32 @@ package integrationtools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
fstools "github.com/sipeed/picoclaw/pkg/tools/fs"
|
||||
)
|
||||
|
||||
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
|
||||
type SendCallbackWithContext func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error
|
||||
|
||||
// messageMediaArg is one entry of the message tool's "media" argument after
// parsing. parseMessageMediaArgs fills it with whitespace-trimmed values.
type messageMediaArg struct {
|
||||
// Path is the local file path; parseMessageMediaArgs rejects an empty value.
Path string
|
||||
// Type is the optional media type hint; empty when none was supplied.
Type string
|
||||
// Filename is the optional display name; empty when none was supplied.
Filename string
|
||||
}
|
||||
|
||||
// sentTarget records the channel+chatID that the message tool sent to.
|
||||
type sentTarget struct {
|
||||
@@ -15,11 +37,15 @@ type sentTarget struct {
|
||||
}
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallbackWithContext
|
||||
mu sync.Mutex
|
||||
// sentTargets tracks targets sent to in the current round, keyed by session key
|
||||
// to support parallel turns for different sessions.
|
||||
sentTargets map[string][]sentTarget
|
||||
sendCallback SendCallbackWithContext
|
||||
workspace string
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
localMediaEnabled bool
|
||||
mu sync.Mutex
|
||||
sentTargets map[string][]sentTarget
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -33,32 +59,86 @@ func (t *MessageTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *MessageTool) Description() string {
|
||||
return "Send a message to user on a chat channel. Use this when you want to communicate something."
|
||||
if !t.localMediaEnabled {
|
||||
return "Send a text message to the user on a chat channel."
|
||||
}
|
||||
return "Send a message to the user on a chat channel. Supports text-only, media-only, or text with media attachments."
|
||||
}
|
||||
|
||||
func (t *MessageTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message content to send",
|
||||
},
|
||||
"channel": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, whatsapp, etc.)",
|
||||
},
|
||||
"chat_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID",
|
||||
},
|
||||
"reply_to_message_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: reply target message ID for channels that support threaded replies",
|
||||
},
|
||||
properties := map[string]any{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional message text. When media is present, this text is used as the caption/body for the media message.",
|
||||
},
|
||||
"channel": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, whatsapp, etc.)",
|
||||
},
|
||||
"chat_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID",
|
||||
},
|
||||
"reply_to_message_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: reply target message ID for channels that support threaded replies",
|
||||
},
|
||||
"required": []string{"content"},
|
||||
}
|
||||
params := map[string]any{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": []string{"content"},
|
||||
}
|
||||
if t.localMediaEnabled {
|
||||
properties["media"] = map[string]any{
|
||||
"type": "array",
|
||||
"description": "Optional local media attachments to send with the message. Requires tools.message.media_enabled.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the local file. Relative paths are resolved from workspace.",
|
||||
},
|
||||
"type": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional media type hint: image, audio, video, or file.",
|
||||
},
|
||||
"filename": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional display filename. Defaults to the basename of path.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}
|
||||
delete(params, "required")
|
||||
params["anyOf"] = []map[string]any{
|
||||
{"required": []string{"content"}},
|
||||
{"required": []string{"media"}},
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (t *MessageTool) ConfigureLocalMedia(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
allowPaths []*regexp.Regexp,
|
||||
) {
|
||||
t.workspace = workspace
|
||||
t.restrict = restrict
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
t.maxFileSize = maxFileSize
|
||||
t.allowPaths = allowPaths
|
||||
t.localMediaEnabled = true
|
||||
}
|
||||
|
||||
// SetMediaStore sets the media store in which buildMediaParts registers
// attachment files. While it is nil, buildMediaParts fails with
// "media store not configured".
func (t *MessageTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
// ResetSentInRound resets the per-round send tracker for the given session key.
|
||||
@@ -98,9 +178,20 @@ func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
content, _ := args["content"].(string)
|
||||
content = strings.TrimSpace(content)
|
||||
mediaArgs, err := parseMessageMediaArgs(args["media"])
|
||||
if err != nil {
|
||||
return &ToolResult{ForLLM: err.Error(), IsError: true}
|
||||
}
|
||||
if len(mediaArgs) > 0 && !t.localMediaEnabled {
|
||||
return &ToolResult{
|
||||
ForLLM: "message media attachments are disabled; enable tools.message.media_enabled to send local media through message",
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
if content == "" && len(mediaArgs) == 0 {
|
||||
return &ToolResult{ForLLM: "content or media is required", IsError: true}
|
||||
}
|
||||
|
||||
channel, _ := args["channel"].(string)
|
||||
@@ -122,7 +213,12 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
|
||||
parts, err := t.buildMediaParts(channel, chatID, content, mediaArgs)
|
||||
if err != nil {
|
||||
return &ToolResult{ForLLM: err.Error(), IsError: true, Err: err}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID, parts); err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
@@ -135,9 +231,149 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID})
|
||||
t.mu.Unlock()
|
||||
|
||||
// Silent: user already received the message directly
|
||||
status := fmt.Sprintf("Message sent to %s:%s", channel, chatID)
|
||||
if len(parts) > 0 {
|
||||
status = fmt.Sprintf("Message with %d media attachment(s) sent to %s:%s", len(parts), channel, chatID)
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
ForLLM: status,
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
|
||||
func parseMessageMediaArgs(raw any) ([]messageMediaArg, error) {
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
items, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("media must be an array")
|
||||
}
|
||||
result := make([]messageMediaArg, 0, len(items))
|
||||
for i, item := range items {
|
||||
obj, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("media[%d] must be an object", i)
|
||||
}
|
||||
path, _ := obj["path"].(string)
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("media[%d].path is required", i)
|
||||
}
|
||||
typ, _ := obj["type"].(string)
|
||||
filename, _ := obj["filename"].(string)
|
||||
result = append(result, messageMediaArg{
|
||||
Path: path,
|
||||
Type: strings.TrimSpace(typ),
|
||||
Filename: strings.TrimSpace(filename),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *MessageTool) buildMediaParts(
|
||||
channel, chatID, content string,
|
||||
mediaArgs []messageMediaArg,
|
||||
) ([]bus.MediaPart, error) {
|
||||
if len(mediaArgs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if !t.localMediaEnabled {
|
||||
return nil, fmt.Errorf("message media attachments are disabled")
|
||||
}
|
||||
if t.mediaStore == nil {
|
||||
return nil, fmt.Errorf("media store not configured")
|
||||
}
|
||||
if strings.TrimSpace(t.workspace) == "" {
|
||||
return nil, fmt.Errorf("message media delivery is not configured")
|
||||
}
|
||||
|
||||
scope := fmt.Sprintf("tool:message:%s:%s", channel, chatID)
|
||||
parts := make([]bus.MediaPart, 0, len(mediaArgs))
|
||||
for i, item := range mediaArgs {
|
||||
resolved, err := fstools.ValidatePathWithAllowPaths(item.Path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid media[%d].path: %w", i, err)
|
||||
}
|
||||
info, err := os.Stat(resolved)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("media[%d] file not found: %w", i, err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil, fmt.Errorf("media[%d] path is a directory, expected a file", i)
|
||||
}
|
||||
if t.maxFileSize > 0 && info.Size() > int64(t.maxFileSize) {
|
||||
return nil, fmt.Errorf("media[%d] file too large: %d bytes (max %d bytes)", i, info.Size(), t.maxFileSize)
|
||||
}
|
||||
|
||||
filename := item.Filename
|
||||
if filename == "" {
|
||||
filename = filepath.Base(resolved)
|
||||
}
|
||||
contentType := detectMessageMediaType(resolved)
|
||||
partType := normalizeMessageMediaType(item.Type, filename, contentType)
|
||||
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: contentType,
|
||||
Source: "tool:message",
|
||||
CleanupPolicy: media.CleanupPolicyForgetOnly,
|
||||
}, scope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register media[%d]: %w", i, err)
|
||||
}
|
||||
|
||||
part := bus.MediaPart{
|
||||
Type: partType,
|
||||
Ref: ref,
|
||||
Filename: filename,
|
||||
ContentType: contentType,
|
||||
}
|
||||
if i == 0 && content != "" {
|
||||
part.Caption = content
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
func detectMessageMediaType(path string) string {
|
||||
kind, err := filetype.MatchFile(path)
|
||||
if err == nil && kind != filetype.Unknown {
|
||||
return kind.MIME.Value
|
||||
}
|
||||
if ext := filepath.Ext(path); ext != "" {
|
||||
if t := mime.TypeByExtension(ext); t != "" {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
// normalizeMessageMediaType picks the media part type ("image", "audio",
// "video" or "file") for an attachment.
//
// Precedence:
//  1. typeHint, if it is one of the four known types (case-insensitive,
//     surrounding whitespace ignored);
//  2. the major type of contentType ("image/", "audio/", "video/");
//  3. the extension of filename;
//  4. "file" when nothing else matches.
//
// An unrecognized hint is ignored rather than rejected, so detection still runs.
func normalizeMessageMediaType(typeHint, filename, contentType string) string {
	// Normalize the hint once and return the matched value directly.
	switch hint := strings.ToLower(strings.TrimSpace(typeHint)); hint {
	case "image", "audio", "video", "file":
		return hint
	}

	ct := strings.ToLower(strings.TrimSpace(contentType))
	switch {
	case strings.HasPrefix(ct, "image/"):
		return "image"
	case strings.HasPrefix(ct, "audio/"):
		return "audio"
	case strings.HasPrefix(ct, "video/"):
		return "video"
	}

	// Generic content types (e.g. application/octet-stream) fall through to the
	// extension of the display filename.
	switch strings.ToLower(filepath.Ext(filename)) {
	case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp":
		return "image"
	case ".mp3", ".wav", ".ogg", ".oga", ".m4a", ".flac":
		return "audio"
	case ".mp4", ".mov", ".mkv", ".webm", ".avi":
		return "video"
	default:
		return "file"
	}
}
|
||||
|
||||
@@ -3,8 +3,13 @@ package integrationtools
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
@@ -12,10 +17,17 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
if len(mediaParts) != 0 {
|
||||
t.Fatalf("expected no media parts, got %d", len(mediaParts))
|
||||
}
|
||||
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
|
||||
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
|
||||
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
|
||||
@@ -67,7 +79,11 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
@@ -102,7 +118,11 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
@@ -142,12 +162,12 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error result for missing content
|
||||
// Verify error result for missing content/media
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for missing content")
|
||||
t.Error("Expected IsError=true for missing content/media")
|
||||
}
|
||||
if result.ForLLM != "content is required" {
|
||||
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
|
||||
if result.ForLLM != "content or media is required" {
|
||||
t.Errorf("Expected ForLLM 'content or media is required', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +175,11 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -228,7 +252,7 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
// Check required properties
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok || len(required) != 1 || required[0] != "content" {
|
||||
t.Error("Expected 'content' to be required")
|
||||
t.Fatal("Expected content-only required schema when local media is disabled")
|
||||
}
|
||||
|
||||
// Check content property
|
||||
@@ -240,6 +264,10 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
t.Error("Expected content type to be 'string'")
|
||||
}
|
||||
|
||||
if _, hasMedia := props["media"]; hasMedia {
|
||||
t.Fatal("did not expect 'media' property when local media is disabled")
|
||||
}
|
||||
|
||||
// Check channel property (optional)
|
||||
channelProp, ok := props["channel"].(map[string]any)
|
||||
if !ok {
|
||||
@@ -268,11 +296,65 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Parameters_WithLocalMediaEnabled(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.ConfigureLocalMedia(t.TempDir(), true, 1024*1024, nil)
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected properties to be a map")
|
||||
}
|
||||
mediaProp, ok := props["media"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected 'media' property")
|
||||
}
|
||||
if mediaProp["type"] != "array" {
|
||||
t.Error("Expected media type to be 'array'")
|
||||
}
|
||||
anyOf, ok := params["anyOf"].([]map[string]any)
|
||||
if !ok || len(anyOf) != 2 {
|
||||
t.Fatal("Expected anyOf content/media requirement")
|
||||
}
|
||||
if _, ok := params["required"]; ok {
|
||||
t.Fatal("did not expect top-level required content when media is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithMediaDisabled(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
t.Fatal("send callback should not run when message media is disabled")
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "telegram", "-1001")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"media": []any{
|
||||
map[string]any{"path": "photo.jpg"},
|
||||
},
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when message media is disabled")
|
||||
}
|
||||
if result.ForLLM != "message media attachments are disabled; enable tools.message.media_enabled to send local media through message" {
|
||||
t.Fatalf("unexpected error: %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentReplyTo string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
sentReplyTo = replyToMessageID
|
||||
return nil
|
||||
})
|
||||
@@ -297,7 +379,11 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
|
||||
var gotAgentID, gotSessionKey string
|
||||
var gotScope *session.SessionScope
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
gotAgentID = ToolAgentID(ctx)
|
||||
gotSessionKey = ToolSessionKey(ctx)
|
||||
gotScope = ToolSessionScope(ctx)
|
||||
@@ -329,3 +415,55 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithMedia(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
imgPath := filepath.Join(dir, "photo.jpg")
|
||||
if err := os.WriteFile(imgPath, []byte("fake image bytes"), 0o644); err != nil {
|
||||
t.Fatalf("write image: %v", err)
|
||||
}
|
||||
tool.ConfigureLocalMedia(dir, true, 1024*1024, []*regexp.Regexp{})
|
||||
tool.SetMediaStore(store)
|
||||
|
||||
var gotContent string
|
||||
var gotParts []bus.MediaPart
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
gotContent = content
|
||||
gotParts = append([]bus.MediaPart(nil), mediaParts...)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "telegram", "-1001")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"content": "Caption text",
|
||||
"media": []any{
|
||||
map[string]any{
|
||||
"path": imgPath,
|
||||
},
|
||||
},
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotContent != "Caption text" {
|
||||
t.Fatalf("content = %q, want Caption text", gotContent)
|
||||
}
|
||||
if len(gotParts) != 1 {
|
||||
t.Fatalf("expected 1 media part, got %d", len(gotParts))
|
||||
}
|
||||
if gotParts[0].Caption != "Caption text" {
|
||||
t.Fatalf("first part caption = %q, want Caption text", gotParts[0].Caption)
|
||||
}
|
||||
if gotParts[0].Ref == "" {
|
||||
t.Fatal("expected media ref to be populated")
|
||||
}
|
||||
if gotParts[0].Type == "" {
|
||||
t.Fatal("expected media type to be inferred")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user