From 28f69e71cc2bcc2f15319f06f3d4fe8833301671 Mon Sep 17 00:00:00 2001 From: reusu Date: Sat, 28 Mar 2026 22:49:54 +0800 Subject: [PATCH] fix: address load_image PR review feedback - Exclude load_image from sub-agent tools via Unregister after Clone, since RunToolLoop does not call resolveMediaRefs - Add ToolRegistry.Unregister() method - Fix scope collision: use channel:chatID instead of filename - Add channel/chatID context resolution matching send_file pattern - Add comment explaining iteration > 1 guard on resolveMediaRefs - Remove emoji from ForUser for consistency with send_file - Add load_image_test.go --- pkg/agent/loop.go | 11 +++++- pkg/tools/load_image.go | 55 ++++++++++++++++++-------- pkg/tools/load_image_test.go | 77 ++++++++++++++++++++++++++++++++++++ pkg/tools/registry.go | 12 ++++++ 4 files changed, 137 insertions(+), 18 deletions(-) create mode 100644 pkg/tools/load_image_test.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 819eeb4e6..ea3a37f7a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -363,7 +363,12 @@ func registerSharedTools( // tools registered so far (file, web, etc.) but NOT spawn/ // spawn_status which are added below — preventing recursive // subagent spawning. - subagentManager.SetTools(agent.Tools.Clone()) + subagentTools := agent.Tools.Clone() + // load_image depends on resolveMediaRefs which only runs in + // the main agent loop, not in RunToolLoop. Remove it from + // sub-agent tools so the LLM won't call it in vain. + subagentTools.Unregister("load_image") + subagentManager.SetTools(subagentTools) if spawnEnabled { spawnTool := tools.NewSpawnTool(subagentManager) spawnTool.SetSpawner(NewSubTurnSpawner(al)) @@ -1817,6 +1822,10 @@ turnLoop: providerToolDefs = filtered } + // Resolve media:// refs produced by tool results (e.g. load_image). 
+ // Skipped on iteration 1 because inbound user media is already resolved + // before entering the loop; only subsequent iterations can contain new + // tool-generated media refs that need base64 encoding. if iteration > 1 { messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) } diff --git a/pkg/tools/load_image.go b/pkg/tools/load_image.go index bd386346c..4bad2e0c0 100644 --- a/pkg/tools/load_image.go +++ b/pkg/tools/load_image.go @@ -22,33 +22,36 @@ import ( // - SendFileTool → MediaResult + WithResponseHandled() → sends file to user, ends turn // - LoadImageTool → plain ToolResult with media:// in ForLLM → LLM sees the image next turn type LoadImageTool struct { - workspace string - restrict bool - maxSize int - mediaStore media.MediaStore - allowPaths []*regexp.Regexp + workspace string + restrict bool + maxFileSize int + mediaStore media.MediaStore + allowPaths []*regexp.Regexp + + defaultChannel string + defaultChatID string } func NewLoadImageTool( workspace string, restrict bool, - maxSize int, + maxFileSize int, store media.MediaStore, allowPaths ...[]*regexp.Regexp, ) *LoadImageTool { - if maxSize <= 0 { - maxSize = config.DefaultMaxMediaSize + if maxFileSize <= 0 { + maxFileSize = config.DefaultMaxMediaSize } var patterns []*regexp.Regexp if len(allowPaths) > 0 { patterns = allowPaths[0] } return &LoadImageTool{ - workspace: workspace, - restrict: restrict, - maxSize: maxSize, - mediaStore: store, - allowPaths: patterns, + workspace: workspace, + restrict: restrict, + maxFileSize: maxFileSize, + mediaStore: store, + allowPaths: patterns, } } @@ -77,6 +80,11 @@ func (t *LoadImageTool) Parameters() map[string]any { } } +func (t *LoadImageTool) SetContext(channel, chatID string) { + t.defaultChannel = channel + t.defaultChatID = chatID +} + func (t *LoadImageTool) SetMediaStore(store media.MediaStore) { t.mediaStore = store } @@ -87,6 +95,19 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR return 
ErrorResult("path is required") } + // Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values. + channel := ToolChannel(ctx) + if channel == "" { + channel = t.defaultChannel + } + chatID := ToolChatID(ctx) + if chatID == "" { + chatID = t.defaultChatID + } + if channel == "" || chatID == "" { + return ErrorResult("no target channel/chat available") + } + if t.mediaStore == nil { return ErrorResult("media store not configured") } @@ -103,9 +124,9 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR if info.IsDir() { return ErrorResult("path is a directory, expected an image file") } - if info.Size() > int64(t.maxSize) { + if info.Size() > int64(t.maxFileSize) { return ErrorResult(fmt.Sprintf( - "file too large: %d bytes (max %d bytes)", info.Size(), t.maxSize, + "file too large: %d bytes (max %d bytes)", info.Size(), t.maxFileSize, )) } @@ -118,7 +139,7 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR } filename := filepath.Base(resolved) - scope := fmt.Sprintf("tool:load_image:%s", filename) + scope := fmt.Sprintf("tool:load_image:%s:%s", channel, chatID) ref, err := t.mediaStore.Store(resolved, media.MediaMeta{ Filename: filename, @@ -143,7 +164,7 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR return &ToolResult{ ForLLM: msg, - ForUser: fmt.Sprintf("📷 Loaded image: %s", filename), + ForUser: fmt.Sprintf("Loaded image: %s", filename), // Media refs inside ForLLM are resolved by resolveMediaRefs in the // agent loop before the next LLM call. Do NOT use MediaResult here — // that would send the file to the user channel instead. 
diff --git a/pkg/tools/load_image_test.go b/pkg/tools/load_image_test.go new file mode 100644 index 000000000..33820bba7 --- /dev/null +++ b/pkg/tools/load_image_test.go @@ -0,0 +1,77 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" +) + +func TestLoadImage_PathRequired(t *testing.T) { + tool := NewLoadImageTool("/tmp", false, 0, nil) + ctx := WithToolContext(context.Background(), "test", "chat1") + result := tool.Execute(ctx, map[string]any{}) + if !result.IsError { + t.Fatal("expected error for missing path") + } +} + +func TestLoadImage_NilMediaStore(t *testing.T) { + tool := NewLoadImageTool("/tmp", false, 0, nil) + ctx := WithToolContext(context.Background(), "test", "chat1") + result := tool.Execute(ctx, map[string]any{"path": "test.png"}) + if !result.IsError || result.ForLLM != "media store not configured" { + t.Fatalf("expected media store error, got: %s", result.ForLLM) + } +} + +func TestLoadImage_NoChannelContext(t *testing.T) { + store := media.NewFileMediaStore() + tool := NewLoadImageTool("/tmp", false, 0, store) + // No WithToolContext — should fail + result := tool.Execute(context.Background(), map[string]any{"path": "test.png"}) + if !result.IsError || result.ForLLM != "no target channel/chat available" { + t.Fatalf("expected channel error, got: %s", result.ForLLM) + } +} + +func TestLoadImage_NonImageFile(t *testing.T) { + dir := t.TempDir() + txtFile := filepath.Join(dir, "readme.txt") + os.WriteFile(txtFile, []byte("hello"), 0644) + + store := media.NewFileMediaStore() + tool := NewLoadImageTool(dir, false, 0, store) + ctx := WithToolContext(context.Background(), "test", "chat1") + result := tool.Execute(ctx, map[string]any{"path": txtFile}) + if !result.IsError { + t.Fatal("expected error for non-image file") + } +} + +func TestLoadImage_DefaultMaxSize(t *testing.T) { + tool := NewLoadImageTool("/tmp", false, 0, nil) + if 
tool.maxFileSize != config.DefaultMaxMediaSize { + t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize) + } +} + +func TestLoadImage_FileTooLarge(t *testing.T) { + dir := t.TempDir() + bigFile := filepath.Join(dir, "big.png") + // Create a file with PNG header but exceeding max size + data := make([]byte, 1024) + copy(data, []byte{0x89, 0x50, 0x4E, 0x47}) // PNG magic bytes + os.WriteFile(bigFile, data, 0644) + + store := media.NewFileMediaStore() + tool := NewLoadImageTool(dir, false, 512, store) // maxFileSize = 512 + ctx := WithToolContext(context.Background(), "test", "chat1") + result := tool.Execute(ctx, map[string]any{"path": bigFile}) + if !result.IsError { + t.Fatal("expected error for oversized file") + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 56af8d695..bf2adb179 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -395,6 +395,18 @@ func (r *ToolRegistry) Clone() *ToolRegistry { return clone } +// Unregister removes a tool by name. Returns true if the tool was found and removed. +func (r *ToolRegistry) Unregister(name string) bool { + r.mu.Lock() + defer r.mu.Unlock() + if _, exists := r.tools[name]; !exists { + return false + } + delete(r.tools, name) + r.version.Add(1) + return true +} + // Count returns the number of registered tools. func (r *ToolRegistry) Count() int { r.mu.RLock()