mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix: address load_image PR review feedback
- Exclude load_image from sub-agent tools via Unregister after Clone, since RunToolLoop does not call resolveMediaRefs - Add ToolRegistry.Unregister() method - Fix scope collision: use channel:chatID instead of filename - Add channel/chatID context resolution matching send_file pattern - Add comment explaining iteration > 1 guard on resolveMediaRefs - Remove emoji from ForUser for consistency with send_file - Add load_image_test.go
This commit is contained in:
+10
-1
@@ -363,7 +363,12 @@ func registerSharedTools(
|
|||||||
// tools registered so far (file, web, etc.) but NOT spawn/
|
// tools registered so far (file, web, etc.) but NOT spawn/
|
||||||
// spawn_status which are added below — preventing recursive
|
// spawn_status which are added below — preventing recursive
|
||||||
// subagent spawning.
|
// subagent spawning.
|
||||||
subagentManager.SetTools(agent.Tools.Clone())
|
subagentTools := agent.Tools.Clone()
|
||||||
|
// load_image depends on resolveMediaRefs which only runs in
|
||||||
|
// the main agent loop, not in RunToolLoop. Remove it from
|
||||||
|
// sub-agent tools so the LLM won't call it in vain.
|
||||||
|
subagentTools.Unregister("load_image")
|
||||||
|
subagentManager.SetTools(subagentTools)
|
||||||
if spawnEnabled {
|
if spawnEnabled {
|
||||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||||
spawnTool.SetSpawner(NewSubTurnSpawner(al))
|
spawnTool.SetSpawner(NewSubTurnSpawner(al))
|
||||||
@@ -1817,6 +1822,10 @@ turnLoop:
|
|||||||
providerToolDefs = filtered
|
providerToolDefs = filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve media:// refs produced by tool results (e.g. load_image).
|
||||||
|
// Skipped on iteration 1 because inbound user media is already resolved
|
||||||
|
// before entering the loop; only subsequent iterations can contain new
|
||||||
|
// tool-generated media refs that need base64 encoding.
|
||||||
if iteration > 1 {
|
if iteration > 1 {
|
||||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||||
}
|
}
|
||||||
|
|||||||
+38
-17
@@ -22,33 +22,36 @@ import (
|
|||||||
// - SendFileTool → MediaResult + WithResponseHandled() → sends file to user, ends turn
|
// - SendFileTool → MediaResult + WithResponseHandled() → sends file to user, ends turn
|
||||||
// - LoadImageTool → plain ToolResult with media:// in ForLLM → LLM sees the image next turn
|
// - LoadImageTool → plain ToolResult with media:// in ForLLM → LLM sees the image next turn
|
||||||
type LoadImageTool struct {
|
type LoadImageTool struct {
|
||||||
workspace string
|
workspace string
|
||||||
restrict bool
|
restrict bool
|
||||||
maxSize int
|
maxFileSize int
|
||||||
mediaStore media.MediaStore
|
mediaStore media.MediaStore
|
||||||
allowPaths []*regexp.Regexp
|
allowPaths []*regexp.Regexp
|
||||||
|
|
||||||
|
defaultChannel string
|
||||||
|
defaultChatID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLoadImageTool(
|
func NewLoadImageTool(
|
||||||
workspace string,
|
workspace string,
|
||||||
restrict bool,
|
restrict bool,
|
||||||
maxSize int,
|
maxFileSize int,
|
||||||
store media.MediaStore,
|
store media.MediaStore,
|
||||||
allowPaths ...[]*regexp.Regexp,
|
allowPaths ...[]*regexp.Regexp,
|
||||||
) *LoadImageTool {
|
) *LoadImageTool {
|
||||||
if maxSize <= 0 {
|
if maxFileSize <= 0 {
|
||||||
maxSize = config.DefaultMaxMediaSize
|
maxFileSize = config.DefaultMaxMediaSize
|
||||||
}
|
}
|
||||||
var patterns []*regexp.Regexp
|
var patterns []*regexp.Regexp
|
||||||
if len(allowPaths) > 0 {
|
if len(allowPaths) > 0 {
|
||||||
patterns = allowPaths[0]
|
patterns = allowPaths[0]
|
||||||
}
|
}
|
||||||
return &LoadImageTool{
|
return &LoadImageTool{
|
||||||
workspace: workspace,
|
workspace: workspace,
|
||||||
restrict: restrict,
|
restrict: restrict,
|
||||||
maxSize: maxSize,
|
maxFileSize: maxFileSize,
|
||||||
mediaStore: store,
|
mediaStore: store,
|
||||||
allowPaths: patterns,
|
allowPaths: patterns,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +80,11 @@ func (t *LoadImageTool) Parameters() map[string]any {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *LoadImageTool) SetContext(channel, chatID string) {
|
||||||
|
t.defaultChannel = channel
|
||||||
|
t.defaultChatID = chatID
|
||||||
|
}
|
||||||
|
|
||||||
func (t *LoadImageTool) SetMediaStore(store media.MediaStore) {
|
func (t *LoadImageTool) SetMediaStore(store media.MediaStore) {
|
||||||
t.mediaStore = store
|
t.mediaStore = store
|
||||||
}
|
}
|
||||||
@@ -87,6 +95,19 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
|||||||
return ErrorResult("path is required")
|
return ErrorResult("path is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values.
|
||||||
|
channel := ToolChannel(ctx)
|
||||||
|
if channel == "" {
|
||||||
|
channel = t.defaultChannel
|
||||||
|
}
|
||||||
|
chatID := ToolChatID(ctx)
|
||||||
|
if chatID == "" {
|
||||||
|
chatID = t.defaultChatID
|
||||||
|
}
|
||||||
|
if channel == "" || chatID == "" {
|
||||||
|
return ErrorResult("no target channel/chat available")
|
||||||
|
}
|
||||||
|
|
||||||
if t.mediaStore == nil {
|
if t.mediaStore == nil {
|
||||||
return ErrorResult("media store not configured")
|
return ErrorResult("media store not configured")
|
||||||
}
|
}
|
||||||
@@ -103,9 +124,9 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
|||||||
if info.IsDir() {
|
if info.IsDir() {
|
||||||
return ErrorResult("path is a directory, expected an image file")
|
return ErrorResult("path is a directory, expected an image file")
|
||||||
}
|
}
|
||||||
if info.Size() > int64(t.maxSize) {
|
if info.Size() > int64(t.maxFileSize) {
|
||||||
return ErrorResult(fmt.Sprintf(
|
return ErrorResult(fmt.Sprintf(
|
||||||
"file too large: %d bytes (max %d bytes)", info.Size(), t.maxSize,
|
"file too large: %d bytes (max %d bytes)", info.Size(), t.maxFileSize,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +139,7 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
|||||||
}
|
}
|
||||||
|
|
||||||
filename := filepath.Base(resolved)
|
filename := filepath.Base(resolved)
|
||||||
scope := fmt.Sprintf("tool:load_image:%s", filename)
|
scope := fmt.Sprintf("tool:load_image:%s:%s", channel, chatID)
|
||||||
|
|
||||||
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
|
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
|
||||||
Filename: filename,
|
Filename: filename,
|
||||||
@@ -143,7 +164,7 @@ func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
|||||||
|
|
||||||
return &ToolResult{
|
return &ToolResult{
|
||||||
ForLLM: msg,
|
ForLLM: msg,
|
||||||
ForUser: fmt.Sprintf("📷 Loaded image: %s", filename),
|
ForUser: fmt.Sprintf("Loaded image: %s", filename),
|
||||||
// Media refs inside ForLLM are resolved by resolveMediaRefs in the
|
// Media refs inside ForLLM are resolved by resolveMediaRefs in the
|
||||||
// agent loop before the next LLM call. Do NOT use MediaResult here —
|
// agent loop before the next LLM call. Do NOT use MediaResult here —
|
||||||
// that would send the file to the user channel instead.
|
// that would send the file to the user channel instead.
|
||||||
|
|||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/media"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadImage_PathRequired(t *testing.T) {
|
||||||
|
tool := NewLoadImageTool("/tmp", false, 0, nil)
|
||||||
|
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||||
|
result := tool.Execute(ctx, map[string]any{})
|
||||||
|
if !result.IsError {
|
||||||
|
t.Fatal("expected error for missing path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadImage_NilMediaStore(t *testing.T) {
|
||||||
|
tool := NewLoadImageTool("/tmp", false, 0, nil)
|
||||||
|
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||||
|
result := tool.Execute(ctx, map[string]any{"path": "test.png"})
|
||||||
|
if !result.IsError || result.ForLLM != "media store not configured" {
|
||||||
|
t.Fatalf("expected media store error, got: %s", result.ForLLM)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadImage_NoChannelContext(t *testing.T) {
|
||||||
|
store := media.NewFileMediaStore()
|
||||||
|
tool := NewLoadImageTool("/tmp", false, 0, store)
|
||||||
|
// No WithToolContext — should fail
|
||||||
|
result := tool.Execute(context.Background(), map[string]any{"path": "test.png"})
|
||||||
|
if !result.IsError || result.ForLLM != "no target channel/chat available" {
|
||||||
|
t.Fatalf("expected channel error, got: %s", result.ForLLM)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadImage_NonImageFile(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
txtFile := filepath.Join(dir, "readme.txt")
|
||||||
|
os.WriteFile(txtFile, []byte("hello"), 0644)
|
||||||
|
|
||||||
|
store := media.NewFileMediaStore()
|
||||||
|
tool := NewLoadImageTool(dir, false, 0, store)
|
||||||
|
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||||
|
result := tool.Execute(ctx, map[string]any{"path": txtFile})
|
||||||
|
if !result.IsError {
|
||||||
|
t.Fatal("expected error for non-image file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadImage_DefaultMaxSize(t *testing.T) {
|
||||||
|
tool := NewLoadImageTool("/tmp", false, 0, nil)
|
||||||
|
if tool.maxFileSize != config.DefaultMaxMediaSize {
|
||||||
|
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadImage_FileTooLarge(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
bigFile := filepath.Join(dir, "big.png")
|
||||||
|
// Create a file with PNG header but exceeding max size
|
||||||
|
data := make([]byte, 1024)
|
||||||
|
copy(data, []byte{0x89, 0x50, 0x4E, 0x47}) // PNG magic bytes
|
||||||
|
os.WriteFile(bigFile, data, 0644)
|
||||||
|
|
||||||
|
store := media.NewFileMediaStore()
|
||||||
|
tool := NewLoadImageTool(dir, false, 512, store) // maxSize = 512
|
||||||
|
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||||
|
result := tool.Execute(ctx, map[string]any{"path": bigFile})
|
||||||
|
if !result.IsError {
|
||||||
|
t.Fatal("expected error for oversized file")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -395,6 +395,18 @@ func (r *ToolRegistry) Clone() *ToolRegistry {
|
|||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unregister removes a tool by name. Returns true if the tool was found and removed.
|
||||||
|
func (r *ToolRegistry) Unregister(name string) bool {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if _, exists := r.tools[name]; !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(r.tools, name)
|
||||||
|
r.version.Add(1)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Count returns the number of registered tools.
|
// Count returns the number of registered tools.
|
||||||
func (r *ToolRegistry) Count() int {
|
func (r *ToolRegistry) Count() int {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
|
|||||||
Reference in New Issue
Block a user