fix: address load_image PR review feedback

- Exclude load_image from sub-agent tools via Unregister after Clone,
  since RunToolLoop does not call resolveMediaRefs
- Add ToolRegistry.Unregister() method
- Fix scope collision: use channel:chatID instead of filename
- Add channel/chatID context resolution matching send_file pattern
- Add comment explaining iteration > 1 guard on resolveMediaRefs
- Remove emoji from ForUser for consistency with send_file
- Add load_image_test.go
This commit is contained in:
reusu
2026-03-28 22:49:54 +08:00
parent 66924457bc
commit 28f69e71cc
4 changed files with 137 additions and 18 deletions
+10 -1
View File
@@ -363,7 +363,12 @@ func registerSharedTools(
// tools registered so far (file, web, etc.) but NOT spawn/ // tools registered so far (file, web, etc.) but NOT spawn/
// spawn_status which are added below — preventing recursive // spawn_status which are added below — preventing recursive
// subagent spawning. // subagent spawning.
subagentManager.SetTools(agent.Tools.Clone()) subagentTools := agent.Tools.Clone()
// load_image depends on resolveMediaRefs which only runs in
// the main agent loop, not in RunToolLoop. Remove it from
// sub-agent tools so the LLM won't call it in vain.
subagentTools.Unregister("load_image")
subagentManager.SetTools(subagentTools)
if spawnEnabled { if spawnEnabled {
spawnTool := tools.NewSpawnTool(subagentManager) spawnTool := tools.NewSpawnTool(subagentManager)
spawnTool.SetSpawner(NewSubTurnSpawner(al)) spawnTool.SetSpawner(NewSubTurnSpawner(al))
@@ -1817,6 +1822,10 @@ turnLoop:
providerToolDefs = filtered providerToolDefs = filtered
} }
// Resolve media:// refs produced by tool results (e.g. load_image).
// Skipped on iteration 1 because inbound user media is already resolved
// before entering the loop; only subsequent iterations can contain new
// tool-generated media refs that need base64 encoding.
if iteration > 1 { if iteration > 1 {
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
} }
+38 -17
View File
@@ -22,33 +22,36 @@ import (
// - SendFileTool → MediaResult + WithResponseHandled() → sends file to user, ends turn // - SendFileTool → MediaResult + WithResponseHandled() → sends file to user, ends turn
// - LoadImageTool → plain ToolResult with media:// in ForLLM → LLM sees the image next turn // - LoadImageTool → plain ToolResult with media:// in ForLLM → LLM sees the image next turn
type LoadImageTool struct { type LoadImageTool struct {
workspace string workspace string
restrict bool restrict bool
maxSize int maxFileSize int
mediaStore media.MediaStore mediaStore media.MediaStore
allowPaths []*regexp.Regexp allowPaths []*regexp.Regexp
defaultChannel string
defaultChatID string
} }
func NewLoadImageTool( func NewLoadImageTool(
workspace string, workspace string,
restrict bool, restrict bool,
maxSize int, maxFileSize int,
store media.MediaStore, store media.MediaStore,
allowPaths ...[]*regexp.Regexp, allowPaths ...[]*regexp.Regexp,
) *LoadImageTool { ) *LoadImageTool {
if maxSize <= 0 { if maxFileSize <= 0 {
maxSize = config.DefaultMaxMediaSize maxFileSize = config.DefaultMaxMediaSize
} }
var patterns []*regexp.Regexp var patterns []*regexp.Regexp
if len(allowPaths) > 0 { if len(allowPaths) > 0 {
patterns = allowPaths[0] patterns = allowPaths[0]
} }
return &LoadImageTool{ return &LoadImageTool{
workspace: workspace, workspace: workspace,
restrict: restrict, restrict: restrict,
maxSize: maxSize, maxFileSize: maxFileSize,
mediaStore: store, mediaStore: store,
allowPaths: patterns, allowPaths: patterns,
} }
} }
@@ -77,6 +80,11 @@ func (t *LoadImageTool) Parameters() map[string]any {
} }
} }
func (t *LoadImageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
}
func (t *LoadImageTool) SetMediaStore(store media.MediaStore) { func (t *LoadImageTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store t.mediaStore = store
} }
@@ -87,6 +95,19 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
return ErrorResult("path is required") return ErrorResult("path is required")
} }
// Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values.
channel := ToolChannel(ctx)
if channel == "" {
channel = t.defaultChannel
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = t.defaultChatID
}
if channel == "" || chatID == "" {
return ErrorResult("no target channel/chat available")
}
if t.mediaStore == nil { if t.mediaStore == nil {
return ErrorResult("media store not configured") return ErrorResult("media store not configured")
} }
@@ -103,9 +124,9 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
if info.IsDir() { if info.IsDir() {
return ErrorResult("path is a directory, expected an image file") return ErrorResult("path is a directory, expected an image file")
} }
if info.Size() > int64(t.maxSize) { if info.Size() > int64(t.maxFileSize) {
return ErrorResult(fmt.Sprintf( return ErrorResult(fmt.Sprintf(
"file too large: %d bytes (max %d bytes)", info.Size(), t.maxSize, "file too large: %d bytes (max %d bytes)", info.Size(), t.maxFileSize,
)) ))
} }
@@ -118,7 +139,7 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
} }
filename := filepath.Base(resolved) filename := filepath.Base(resolved)
scope := fmt.Sprintf("tool:load_image:%s", filename) scope := fmt.Sprintf("tool:load_image:%s:%s", channel, chatID)
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{ ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
Filename: filename, Filename: filename,
@@ -143,7 +164,7 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
return &ToolResult{ return &ToolResult{
ForLLM: msg, ForLLM: msg,
ForUser: fmt.Sprintf("📷 Loaded image: %s", filename), ForUser: fmt.Sprintf("Loaded image: %s", filename),
// Media refs inside ForLLM are resolved by resolveMediaRefs in the // Media refs inside ForLLM are resolved by resolveMediaRefs in the
// agent loop before the next LLM call. Do NOT use MediaResult here — // agent loop before the next LLM call. Do NOT use MediaResult here —
// that would send the file to the user channel instead. // that would send the file to the user channel instead.
+77
View File
@@ -0,0 +1,77 @@
package tools
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestLoadImage_PathRequired(t *testing.T) {
tool := NewLoadImageTool("/tmp", false, 0, nil)
ctx := WithToolContext(context.Background(), "test", "chat1")
result := tool.Execute(ctx, map[string]any{})
if !result.IsError {
t.Fatal("expected error for missing path")
}
}
func TestLoadImage_NilMediaStore(t *testing.T) {
tool := NewLoadImageTool("/tmp", false, 0, nil)
ctx := WithToolContext(context.Background(), "test", "chat1")
result := tool.Execute(ctx, map[string]any{"path": "test.png"})
if !result.IsError || result.ForLLM != "media store not configured" {
t.Fatalf("expected media store error, got: %s", result.ForLLM)
}
}
func TestLoadImage_NoChannelContext(t *testing.T) {
store := media.NewFileMediaStore()
tool := NewLoadImageTool("/tmp", false, 0, store)
// No WithToolContext — should fail
result := tool.Execute(context.Background(), map[string]any{"path": "test.png"})
if !result.IsError || result.ForLLM != "no target channel/chat available" {
t.Fatalf("expected channel error, got: %s", result.ForLLM)
}
}
func TestLoadImage_NonImageFile(t *testing.T) {
dir := t.TempDir()
txtFile := filepath.Join(dir, "readme.txt")
os.WriteFile(txtFile, []byte("hello"), 0644)
store := media.NewFileMediaStore()
tool := NewLoadImageTool(dir, false, 0, store)
ctx := WithToolContext(context.Background(), "test", "chat1")
result := tool.Execute(ctx, map[string]any{"path": txtFile})
if !result.IsError {
t.Fatal("expected error for non-image file")
}
}
func TestLoadImage_DefaultMaxSize(t *testing.T) {
tool := NewLoadImageTool("/tmp", false, 0, nil)
if tool.maxFileSize != config.DefaultMaxMediaSize {
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
}
}
func TestLoadImage_FileTooLarge(t *testing.T) {
dir := t.TempDir()
bigFile := filepath.Join(dir, "big.png")
// Create a file with PNG header but exceeding max size
data := make([]byte, 1024)
copy(data, []byte{0x89, 0x50, 0x4E, 0x47}) // PNG magic bytes
os.WriteFile(bigFile, data, 0644)
store := media.NewFileMediaStore()
tool := NewLoadImageTool(dir, false, 512, store) // maxSize = 512
ctx := WithToolContext(context.Background(), "test", "chat1")
result := tool.Execute(ctx, map[string]any{"path": bigFile})
if !result.IsError {
t.Fatal("expected error for oversized file")
}
}
+12
View File
@@ -395,6 +395,18 @@ func (r *ToolRegistry) Clone() *ToolRegistry {
return clone return clone
} }
// Unregister removes a tool by name. Returns true if the tool was found and removed.
func (r *ToolRegistry) Unregister(name string) bool {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.tools[name]; !exists {
return false
}
delete(r.tools, name)
r.version.Add(1)
return true
}
// Count returns the number of registered tools. // Count returns the number of registered tools.
func (r *ToolRegistry) Count() int { func (r *ToolRegistry) Count() int {
r.mu.RLock() r.mu.RLock()