mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(message): support media attachments in outbound tool
This commit is contained in:
@@ -3,10 +3,32 @@ package integrationtools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
fstools "github.com/sipeed/picoclaw/pkg/tools/fs"
|
||||
)
|
||||
|
||||
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
|
||||
type SendCallbackWithContext func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error
|
||||
|
||||
type messageMediaArg struct {
|
||||
Path string
|
||||
Type string
|
||||
Filename string
|
||||
}
|
||||
|
||||
// sentTarget records the channel+chatID that the message tool sent to.
|
||||
type sentTarget struct {
|
||||
@@ -16,10 +38,13 @@ type sentTarget struct {
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallbackWithContext
|
||||
workspace string
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
mu sync.Mutex
|
||||
// sentTargets tracks targets sent to in the current round, keyed by session key
|
||||
// to support parallel turns for different sessions.
|
||||
sentTargets map[string][]sentTarget
|
||||
sentTargets map[string][]sentTarget
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -33,7 +58,7 @@ func (t *MessageTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *MessageTool) Description() string {
|
||||
return "Send a message to user on a chat channel. Use this when you want to communicate something."
|
||||
return "Send a message to the user on a chat channel. Supports text-only, media-only, or text with media attachments."
|
||||
}
|
||||
|
||||
func (t *MessageTool) Parameters() map[string]any {
|
||||
@@ -42,7 +67,29 @@ func (t *MessageTool) Parameters() map[string]any {
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message content to send",
|
||||
"description": "Optional message text. When media is present, this text is used as the caption/body for the media message.",
|
||||
},
|
||||
"media": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Optional local media attachments to send with the message.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the local file. Relative paths are resolved from workspace.",
|
||||
},
|
||||
"type": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional media type hint: image, audio, video, or file.",
|
||||
},
|
||||
"filename": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional display filename. Defaults to the basename of path.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
},
|
||||
"channel": map[string]any{
|
||||
"type": "string",
|
||||
@@ -57,10 +104,32 @@ func (t *MessageTool) Parameters() map[string]any {
|
||||
"description": "Optional: reply target message ID for channels that support threaded replies",
|
||||
},
|
||||
},
|
||||
"required": []string{"content"},
|
||||
"anyOf": []map[string]any{
|
||||
{"required": []string{"content"}},
|
||||
{"required": []string{"media"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MessageTool) ConfigureLocalMedia(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
allowPaths []*regexp.Regexp,
|
||||
) {
|
||||
t.workspace = workspace
|
||||
t.restrict = restrict
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
t.maxFileSize = maxFileSize
|
||||
t.allowPaths = allowPaths
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
// ResetSentInRound resets the per-round send tracker for the given session key.
|
||||
// Called by the agent loop at the start of each inbound message processing round.
|
||||
func (t *MessageTool) ResetSentInRound(sessionKey string) {
|
||||
@@ -98,9 +167,14 @@ func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
content, _ := args["content"].(string)
|
||||
content = strings.TrimSpace(content)
|
||||
mediaArgs, err := parseMessageMediaArgs(args["media"])
|
||||
if err != nil {
|
||||
return &ToolResult{ForLLM: err.Error(), IsError: true}
|
||||
}
|
||||
if content == "" && len(mediaArgs) == 0 {
|
||||
return &ToolResult{ForLLM: "content or media is required", IsError: true}
|
||||
}
|
||||
|
||||
channel, _ := args["channel"].(string)
|
||||
@@ -122,7 +196,12 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
|
||||
parts, err := t.buildMediaParts(channel, chatID, content, mediaArgs)
|
||||
if err != nil {
|
||||
return &ToolResult{ForLLM: err.Error(), IsError: true, Err: err}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID, parts); err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
@@ -135,9 +214,146 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID})
|
||||
t.mu.Unlock()
|
||||
|
||||
// Silent: user already received the message directly
|
||||
status := fmt.Sprintf("Message sent to %s:%s", channel, chatID)
|
||||
if len(parts) > 0 {
|
||||
status = fmt.Sprintf("Message with %d media attachment(s) sent to %s:%s", len(parts), channel, chatID)
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
ForLLM: status,
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
|
||||
func parseMessageMediaArgs(raw any) ([]messageMediaArg, error) {
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
items, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("media must be an array")
|
||||
}
|
||||
result := make([]messageMediaArg, 0, len(items))
|
||||
for i, item := range items {
|
||||
obj, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("media[%d] must be an object", i)
|
||||
}
|
||||
path, _ := obj["path"].(string)
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("media[%d].path is required", i)
|
||||
}
|
||||
typ, _ := obj["type"].(string)
|
||||
filename, _ := obj["filename"].(string)
|
||||
result = append(result, messageMediaArg{
|
||||
Path: path,
|
||||
Type: strings.TrimSpace(typ),
|
||||
Filename: strings.TrimSpace(filename),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *MessageTool) buildMediaParts(
|
||||
channel, chatID, content string,
|
||||
mediaArgs []messageMediaArg,
|
||||
) ([]bus.MediaPart, error) {
|
||||
if len(mediaArgs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if t.mediaStore == nil {
|
||||
return nil, fmt.Errorf("media store not configured")
|
||||
}
|
||||
if strings.TrimSpace(t.workspace) == "" {
|
||||
return nil, fmt.Errorf("message media delivery is not configured")
|
||||
}
|
||||
|
||||
scope := fmt.Sprintf("tool:message:%s:%s", channel, chatID)
|
||||
parts := make([]bus.MediaPart, 0, len(mediaArgs))
|
||||
for i, item := range mediaArgs {
|
||||
resolved, err := fstools.ValidatePathWithAllowPaths(item.Path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid media[%d].path: %w", i, err)
|
||||
}
|
||||
info, err := os.Stat(resolved)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("media[%d] file not found: %w", i, err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil, fmt.Errorf("media[%d] path is a directory, expected a file", i)
|
||||
}
|
||||
if t.maxFileSize > 0 && info.Size() > int64(t.maxFileSize) {
|
||||
return nil, fmt.Errorf("media[%d] file too large: %d bytes (max %d bytes)", i, info.Size(), t.maxFileSize)
|
||||
}
|
||||
|
||||
filename := item.Filename
|
||||
if filename == "" {
|
||||
filename = filepath.Base(resolved)
|
||||
}
|
||||
contentType := detectMessageMediaType(resolved)
|
||||
partType := normalizeMessageMediaType(item.Type, filename, contentType)
|
||||
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: contentType,
|
||||
Source: "tool:message",
|
||||
CleanupPolicy: media.CleanupPolicyForgetOnly,
|
||||
}, scope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register media[%d]: %w", i, err)
|
||||
}
|
||||
|
||||
part := bus.MediaPart{
|
||||
Type: partType,
|
||||
Ref: ref,
|
||||
Filename: filename,
|
||||
ContentType: contentType,
|
||||
}
|
||||
if i == 0 && content != "" {
|
||||
part.Caption = content
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
func detectMessageMediaType(path string) string {
|
||||
kind, err := filetype.MatchFile(path)
|
||||
if err == nil && kind != filetype.Unknown {
|
||||
return kind.MIME.Value
|
||||
}
|
||||
if ext := filepath.Ext(path); ext != "" {
|
||||
if t := mime.TypeByExtension(ext); t != "" {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
func normalizeMessageMediaType(typeHint, filename, contentType string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(typeHint)) {
|
||||
case "image", "audio", "video", "file":
|
||||
return strings.ToLower(strings.TrimSpace(typeHint))
|
||||
}
|
||||
|
||||
ct := strings.ToLower(strings.TrimSpace(contentType))
|
||||
switch {
|
||||
case strings.HasPrefix(ct, "image/"):
|
||||
return "image"
|
||||
case strings.HasPrefix(ct, "audio/"):
|
||||
return "audio"
|
||||
case strings.HasPrefix(ct, "video/"):
|
||||
return "video"
|
||||
}
|
||||
|
||||
switch strings.ToLower(filepath.Ext(filename)) {
|
||||
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp":
|
||||
return "image"
|
||||
case ".mp3", ".wav", ".ogg", ".oga", ".m4a", ".flac":
|
||||
return "audio"
|
||||
case ".mp4", ".mov", ".mkv", ".webm", ".avi":
|
||||
return "video"
|
||||
default:
|
||||
return "file"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,13 @@ package integrationtools
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
@@ -12,10 +17,17 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
if len(mediaParts) != 0 {
|
||||
t.Fatalf("expected no media parts, got %d", len(mediaParts))
|
||||
}
|
||||
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
|
||||
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
|
||||
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
|
||||
@@ -67,7 +79,11 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
@@ -102,7 +118,11 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
@@ -142,12 +162,12 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error result for missing content
|
||||
// Verify error result for missing content/media
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for missing content")
|
||||
t.Error("Expected IsError=true for missing content/media")
|
||||
}
|
||||
if result.ForLLM != "content is required" {
|
||||
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
|
||||
if result.ForLLM != "content or media is required" {
|
||||
t.Errorf("Expected ForLLM 'content or media is required', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +175,11 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -226,9 +250,9 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check required properties
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok || len(required) != 1 || required[0] != "content" {
|
||||
t.Error("Expected 'content' to be required")
|
||||
anyOf, ok := params["anyOf"].([]map[string]any)
|
||||
if !ok || len(anyOf) != 2 {
|
||||
t.Fatal("Expected anyOf content/media requirement")
|
||||
}
|
||||
|
||||
// Check content property
|
||||
@@ -240,6 +264,14 @@ func TestMessageTool_Parameters(t *testing.T) {
|
||||
t.Error("Expected content type to be 'string'")
|
||||
}
|
||||
|
||||
mediaProp, ok := props["media"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected 'media' property")
|
||||
}
|
||||
if mediaProp["type"] != "array" {
|
||||
t.Error("Expected media type to be 'array'")
|
||||
}
|
||||
|
||||
// Check channel property (optional)
|
||||
channelProp, ok := props["channel"].(map[string]any)
|
||||
if !ok {
|
||||
@@ -272,7 +304,11 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentReplyTo string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
sentReplyTo = replyToMessageID
|
||||
return nil
|
||||
})
|
||||
@@ -297,7 +333,11 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
|
||||
var gotAgentID, gotSessionKey string
|
||||
var gotScope *session.SessionScope
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
gotAgentID = ToolAgentID(ctx)
|
||||
gotSessionKey = ToolSessionKey(ctx)
|
||||
gotScope = ToolSessionScope(ctx)
|
||||
@@ -329,3 +369,55 @@ func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithMedia(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
imgPath := filepath.Join(dir, "photo.jpg")
|
||||
if err := os.WriteFile(imgPath, []byte("fake image bytes"), 0o644); err != nil {
|
||||
t.Fatalf("write image: %v", err)
|
||||
}
|
||||
tool.ConfigureLocalMedia(dir, true, 1024*1024, []*regexp.Regexp{})
|
||||
tool.SetMediaStore(store)
|
||||
|
||||
var gotContent string
|
||||
var gotParts []bus.MediaPart
|
||||
tool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
mediaParts []bus.MediaPart,
|
||||
) error {
|
||||
gotContent = content
|
||||
gotParts = append([]bus.MediaPart(nil), mediaParts...)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "telegram", "-1001")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"content": "Caption text",
|
||||
"media": []any{
|
||||
map[string]any{
|
||||
"path": imgPath,
|
||||
},
|
||||
},
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotContent != "Caption text" {
|
||||
t.Fatalf("content = %q, want Caption text", gotContent)
|
||||
}
|
||||
if len(gotParts) != 1 {
|
||||
t.Fatalf("expected 1 media part, got %d", len(gotParts))
|
||||
}
|
||||
if gotParts[0].Caption != "Caption text" {
|
||||
t.Fatalf("first part caption = %q, want Caption text", gotParts[0].Caption)
|
||||
}
|
||||
if gotParts[0].Ref == "" {
|
||||
t.Fatal("expected media ref to be populated")
|
||||
}
|
||||
if gotParts[0].Type == "" {
|
||||
t.Fatal("expected media type to be inferred")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user