Merge remote-tracking branch 'refs/remotes/origin/main' into fix/deny-reading-binary-files

This commit is contained in:
afjcjsbx
2026-03-09 00:30:37 +01:00
138 changed files with 9432 additions and 953 deletions
+50 -38
View File
@@ -10,11 +10,38 @@ type Tool interface {
Execute(ctx context.Context, args map[string]any) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
// to receive the current message context (channel, chatID)
type ContextualTool interface {
Tool
SetContext(channel, chatID string)
// --- Request-scoped tool context (channel / chatID) ---
//
// Carried via context.Value so that concurrent tool calls each receive
// their own immutable copy — no mutable state on singleton tool instances.
//
// Keys are unexported pointer-typed vars — guaranteed collision-free,
// and only accessible through the helper functions below.
type toolCtxKey struct{ name string }
var (
ctxKeyChannel = &toolCtxKey{"channel"}
ctxKeyChatID = &toolCtxKey{"chatID"}
)
// WithToolContext returns a child context carrying channel and chatID.
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
ctx = context.WithValue(ctx, ctxKeyChannel, channel)
ctx = context.WithValue(ctx, ctxKeyChatID, chatID)
return ctx
}
// ToolChannel extracts the channel from ctx, or "" if unset.
func ToolChannel(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChannel).(string)
return v
}
// ToolChatID extracts the chatID from ctx, or "" if unset.
func ToolChatID(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChatID).(string)
return v
}
// AsyncCallback is a function type that async tools use to notify completion.
@@ -22,51 +49,36 @@ type ContextualTool interface {
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// AsyncExecutor is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
// Unlike the old AsyncTool pattern (SetCallback + Execute), AsyncExecutor
// receives the callback as a parameter of ExecuteAsync. This eliminates the
// data race where concurrent calls could overwrite each other's callbacks
// on a shared tool instance.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
// go func() {
// result := t.runSubagent(ctx, args)
// if cb != nil { cb(ctx, result) }
// }()
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
type AsyncExecutor interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
// ExecuteAsync runs the tool asynchronously. The callback cb will be
// invoked (possibly from another goroutine) when the async operation
// completes. cb is guaranteed to be non-nil by the caller (registry).
ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult
}
func ToolToSchema(tool Tool) map[string]any {
+10 -18
View File
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -24,9 +23,6 @@ type CronTool struct {
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
channel string
chatID string
mu sync.RWMutex
}
// NewCronTool creates a new CronTool
@@ -102,14 +98,6 @@ func (t *CronTool) Parameters() map[string]any {
}
}
// SetContext sets the current session context for job creation
func (t *CronTool) SetContext(channel, chatID string) {
t.mu.Lock()
defer t.mu.Unlock()
t.channel = channel
t.chatID = chatID
}
// Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
action, ok := args["action"].(string)
@@ -119,7 +107,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
switch action {
case "add":
return t.addJob(args)
return t.addJob(ctx, args)
case "list":
return t.listJobs()
case "remove":
@@ -133,11 +121,9 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
}
}
func (t *CronTool) addJob(args map[string]any) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
t.mu.RUnlock()
func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult {
channel := ToolChannel(ctx)
chatID := ToolChatID(ctx)
if channel == "" || chatID == "" {
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
@@ -155,6 +141,12 @@ func (t *CronTool) addJob(args map[string]any) *ToolResult {
everySeconds, hasEvery := args["every_seconds"].(float64)
cronExpr, hasCron := args["cron_expr"].(string)
// Fix: type assertions return true for zero values, need additional validity checks
// This prevents LLMs that fill unused optional parameters with defaults (0) from triggering wrong type
hasAt = hasAt && atSeconds > 0
hasEvery = hasEvery && everySeconds > 0
hasCron = hasCron && cronExpr != ""
// Priority: at_seconds > every_seconds > cron_expr
if hasAt {
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
+8 -10
View File
@@ -9,10 +9,8 @@ import (
type SendCallback func(channel, chatID, content string) error
type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
sendCallback SendCallback
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -48,10 +46,10 @@ func (t *MessageTool) Parameters() map[string]any {
}
}
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound.Store(false) // Reset send tracking for new processing round
// ResetSentInRound resets the per-round send tracker.
// Called by the agent loop at the start of each inbound message processing round.
func (t *MessageTool) ResetSentInRound() {
t.sentInRound.Store(false)
}
// HasSentInRound returns true if the message tool sent a message during the current round.
@@ -73,10 +71,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
chatID, _ := args["chat_id"].(string)
if channel == "" {
channel = t.defaultChannel
channel = ToolChannel(ctx)
}
if chatID == "" {
chatID = t.defaultChatID
chatID = ToolChatID(ctx)
}
if channel == "" || chatID == "" {
+6 -11
View File
@@ -8,7 +8,6 @@ import (
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
@@ -18,7 +17,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
return nil
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Hello, world!",
}
@@ -60,7 +59,6 @@ func TestMessageTool_Execute_Success(t *testing.T) {
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
@@ -69,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
return nil
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id")
args := map[string]any{
"content": "Test message",
"channel": "custom-channel",
@@ -96,14 +94,13 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
@@ -133,9 +130,8 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{} // content missing
result := tool.Execute(ctx, args)
@@ -151,7 +147,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
@@ -175,10 +171,9 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
+15 -13
View File
@@ -45,8 +45,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string
}
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
// If the tool implements AsyncExecutor and a non-nil callback is provided,
// ExecuteAsync is called instead of Execute — the callback is a parameter,
// never stored as mutable state on the tool.
func (r *ToolRegistry) ExecuteWithContext(
ctx context.Context,
name string,
@@ -69,22 +70,23 @@ func (r *ToolRegistry) ExecuteWithContext(
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
}
// If tool implements ContextualTool, set context
if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" {
contextualTool.SetContext(channel, chatID)
}
// Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx).
// Always inject — tools validate what they require.
ctx = WithToolContext(ctx, channel, chatID)
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
// If tool implements AsyncExecutor and callback is provided, use ExecuteAsync.
// The callback is a call parameter, not mutable state on the tool instance.
var result *ToolResult
start := time.Now()
if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
map[string]any{
"tool": name,
})
result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
} else {
result = tool.Execute(ctx, args)
}
start := time.Now()
result := tool.Execute(ctx, args)
duration := time.Since(start)
// Log based on result type
+32 -22
View File
@@ -25,24 +25,24 @@ func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolRes
return m.result
}
type mockCtxTool struct {
type mockContextAwareTool struct {
mockRegistryTool
channel string
chatID string
lastCtx context.Context
}
func (m *mockCtxTool) SetContext(channel, chatID string) {
m.channel = channel
m.chatID = chatID
func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *ToolResult {
m.lastCtx = ctx
return m.result
}
type mockAsyncRegistryTool struct {
mockRegistryTool
cb AsyncCallback
lastCB AsyncCallback
}
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
m.cb = cb
func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
m.lastCB = cb
return m.result
}
// --- helpers ---
@@ -136,34 +136,44 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) {
}
}
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
func TestToolRegistry_ExecuteWithContext_InjectsToolContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
ct := &mockContextAwareTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
if ct.channel != "telegram" {
t.Errorf("expected channel 'telegram', got %q", ct.channel)
if ct.lastCtx == nil {
t.Fatal("expected Execute to be called")
}
if ct.chatID != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
if got := ToolChannel(ct.lastCtx); got != "telegram" {
t.Errorf("expected channel 'telegram', got %q", got)
}
if got := ToolChatID(ct.lastCtx); got != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", got)
}
}
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
ct := &mockContextAwareTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
if ct.channel != "" || ct.chatID != "" {
t.Error("SetContext should not be called with empty channel/chatID")
if ct.lastCtx == nil {
t.Fatal("expected Execute to be called")
}
// Empty values are still injected; tools decide what to do with them.
if got := ToolChannel(ct.lastCtx); got != "" {
t.Errorf("expected empty channel, got %q", got)
}
if got := ToolChatID(ct.lastCtx); got != "" {
t.Errorf("expected empty chatID, got %q", got)
}
}
@@ -179,14 +189,14 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
cb := func(_ context.Context, _ *ToolResult) { called = true }
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
if at.cb == nil {
t.Error("expected SetCallback to have been called")
if at.lastCB == nil {
t.Error("expected ExecuteAsync to have received a callback")
}
if !result.Async {
t.Error("expected async result")
}
at.cb(context.Background(), SilentResult("done"))
at.lastCB(context.Background(), SilentResult("done"))
if !called {
t.Error("expected callback to be invoked")
}
+150
View File
@@ -0,0 +1,150 @@
package tools
import (
"context"
"fmt"
"mime"
"os"
"path/filepath"
"strings"
"github.com/h2non/filetype"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
// SendFileTool allows the LLM to send a local file (image, document, etc.)
// to the user on the current chat channel via the MediaStore pipeline.
type SendFileTool struct {
workspace string
restrict bool
maxFileSize int
mediaStore media.MediaStore
defaultChannel string
defaultChatID string
}
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
if maxFileSize <= 0 {
maxFileSize = config.DefaultMaxMediaSize
}
return &SendFileTool{
workspace: workspace,
restrict: restrict,
maxFileSize: maxFileSize,
mediaStore: store,
}
}
func (t *SendFileTool) Name() string { return "send_file" }
func (t *SendFileTool) Description() string {
return "Send a local file (image, document, etc.) to the user on the current chat channel."
}
func (t *SendFileTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the local file. Relative paths are resolved from workspace.",
},
"filename": map[string]any{
"type": "string",
"description": "Optional display filename. Defaults to the basename of path.",
},
},
"required": []string{"path"},
}
}
func (t *SendFileTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
}
func (t *SendFileTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, _ := args["path"].(string)
if strings.TrimSpace(path) == "" {
return ErrorResult("path is required")
}
// Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values.
channel := ToolChannel(ctx)
if channel == "" {
channel = t.defaultChannel
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = t.defaultChatID
}
if channel == "" || chatID == "" {
return ErrorResult("no target channel/chat available")
}
if t.mediaStore == nil {
return ErrorResult("media store not configured")
}
resolved, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
}
info, err := os.Stat(resolved)
if err != nil {
return ErrorResult(fmt.Sprintf("file not found: %v", err))
}
if info.IsDir() {
return ErrorResult("path is a directory, expected a file")
}
if info.Size() > int64(t.maxFileSize) {
return ErrorResult(fmt.Sprintf(
"file too large: %d bytes (max %d bytes)",
info.Size(), t.maxFileSize,
))
}
filename, _ := args["filename"].(string)
if filename == "" {
filename = filepath.Base(resolved)
}
mediaType := detectMediaType(resolved)
scope := fmt.Sprintf("tool:send_file:%s:%s", channel, chatID)
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
Filename: filename,
ContentType: mediaType,
Source: "tool:send_file",
}, scope)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
}
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref})
}
// detectMediaType determines the MIME type of a file.
// Uses magic-bytes detection (h2non/filetype) first, then falls back to
// extension-based lookup via mime.TypeByExtension.
func detectMediaType(path string) string {
kind, err := filetype.MatchFile(path)
if err == nil && kind != filetype.Unknown {
return kind.MIME.Value
}
if ext := filepath.Ext(path); ext != "" {
if t := mime.TypeByExtension(ext); t != "" {
return t
}
}
return "application/octet-stream"
}
+176
View File
@@ -0,0 +1,176 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestSendFileTool_MissingPath(t *testing.T) {
store := media.NewFileMediaStore()
tool := NewSendFileTool("/tmp", false, 0, store)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{})
if !result.IsError {
t.Fatal("expected error for missing path")
}
}
func TestSendFileTool_NoContext(t *testing.T) {
store := media.NewFileMediaStore()
tool := NewSendFileTool("/tmp", false, 0, store)
// no SetContext call
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"})
if !result.IsError {
t.Fatal("expected error when no channel context")
}
}
func TestSendFileTool_NoMediaStore(t *testing.T) {
tool := NewSendFileTool("/tmp", false, 0, nil)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"})
if !result.IsError {
t.Fatal("expected error when no media store")
}
}
func TestSendFileTool_Directory(t *testing.T) {
store := media.NewFileMediaStore()
tool := NewSendFileTool("/tmp", false, 0, store)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp"})
if !result.IsError {
t.Fatal("expected error for directory path")
}
}
func TestSendFileTool_FileTooLarge(t *testing.T) {
dir := t.TempDir()
testFile := filepath.Join(dir, "big.bin")
// Create a file larger than the limit
if err := os.WriteFile(testFile, make([]byte, 1024), 0o644); err != nil {
t.Fatal(err)
}
store := media.NewFileMediaStore()
tool := NewSendFileTool(dir, false, 512, store) // 512 byte limit
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": testFile})
if !result.IsError {
t.Fatal("expected error for oversized file")
}
if !strings.Contains(result.ForLLM, "too large") {
t.Errorf("expected 'too large' in error, got %q", result.ForLLM)
}
}
func TestSendFileTool_DefaultMaxSize(t *testing.T) {
tool := NewSendFileTool("/tmp", false, 0, nil)
if tool.maxFileSize != config.DefaultMaxMediaSize {
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
}
}
func TestSendFileTool_Success(t *testing.T) {
dir := t.TempDir()
testFile := filepath.Join(dir, "photo.png")
if err := os.WriteFile(testFile, []byte("fake png"), 0o644); err != nil {
t.Fatal(err)
}
store := media.NewFileMediaStore()
tool := NewSendFileTool(dir, false, 0, store)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": testFile})
if result.IsError {
t.Fatalf("unexpected error: %s", result.ForLLM)
}
if len(result.Media) != 1 {
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
}
if result.Media[0][:8] != "media://" {
t.Errorf("expected media:// ref, got %q", result.Media[0])
}
}
func TestSendFileTool_CustomFilename(t *testing.T) {
dir := t.TempDir()
testFile := filepath.Join(dir, "img.jpg")
if err := os.WriteFile(testFile, []byte("fake jpg"), 0o644); err != nil {
t.Fatal(err)
}
store := media.NewFileMediaStore()
tool := NewSendFileTool(dir, false, 0, store)
tool.SetContext("telegram", "chat456")
result := tool.Execute(context.Background(), map[string]any{
"path": testFile,
"filename": "my-photo.jpg",
})
if result.IsError {
t.Fatalf("unexpected error: %s", result.ForLLM)
}
if len(result.Media) != 1 {
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
}
}
func TestDetectMediaType_MagicBytes(t *testing.T) {
dir := t.TempDir()
// Minimal valid PNG header
pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
pngFile := filepath.Join(dir, "image.dat") // wrong extension, but valid PNG bytes
if err := os.WriteFile(pngFile, pngHeader, 0o644); err != nil {
t.Fatal(err)
}
got := detectMediaType(pngFile)
if got != "image/png" {
t.Errorf("expected image/png from magic bytes, got %q", got)
}
}
func TestDetectMediaType_FallbackToExtension(t *testing.T) {
dir := t.TempDir()
// File with unrecognizable content but known extension
txtFile := filepath.Join(dir, "readme.txt")
if err := os.WriteFile(txtFile, []byte("hello world"), 0o644); err != nil {
t.Fatal(err)
}
got := detectMediaType(txtFile)
// text/plain or similar — just verify it's not application/octet-stream
if got == "application/octet-stream" {
t.Errorf("expected extension-based MIME for .txt, got %q", got)
}
}
func TestDetectMediaType_UnknownFallsToOctetStream(t *testing.T) {
dir := t.TempDir()
// File with no extension and random bytes
unknownFile := filepath.Join(dir, "mystery")
if err := os.WriteFile(unknownFile, []byte{0x00, 0x01, 0x02}, 0o644); err != nil {
t.Fatal(err)
}
got := detectMediaType(unknownFile)
if got != "application/octet-stream" {
t.Errorf("expected application/octet-stream, got %q", got)
}
}
+7 -2
View File
@@ -59,7 +59,7 @@ var (
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bkill\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
@@ -131,9 +131,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
timeout := 60 * time.Second
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
}
return &ExecTool{
workingDir: workingDir,
timeout: 60 * time.Second,
timeout: timeout,
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
+20
View File
@@ -151,6 +151,26 @@ func TestShellTool_DangerousCommand(t *testing.T) {
}
}
func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) {
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{
"command": "kill 12345",
}
result := tool.Execute(ctx, args)
if !result.IsError {
t.Errorf("Expected kill command to be blocked")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected blocked message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool, err := NewExecTool("", false)
+27 -17
View File
@@ -8,25 +8,18 @@ import (
type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
allowlistCheck func(targetAgentID string) bool
callback AsyncCallback // For async completion notification
}
// Compile-time check: SpawnTool implements AsyncExecutor.
var _ AsyncExecutor = (*SpawnTool)(nil)
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
return &SpawnTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
manager: manager,
}
}
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string {
return "spawn"
}
@@ -56,16 +49,21 @@ func (t *SpawnTool) Parameters() map[string]any {
}
}
func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
t.allowlistCheck = check
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
return t.execute(ctx, args, nil)
}
// ExecuteAsync implements AsyncExecutor. The callback is passed through to the
// subagent manager as a call parameter — never stored on the SpawnTool instance.
func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
return t.execute(ctx, args, cb)
}
func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
task, ok := args["task"].(string)
if !ok || strings.TrimSpace(task) == "" {
return ErrorResult("task is required and must be a non-empty string")
@@ -85,8 +83,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul
return ErrorResult("Subagent manager not configured")
}
// Read channel/chatID from context (injected by registry).
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSpawnTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
}
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
+2 -2
View File
@@ -8,7 +8,7 @@ import (
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnTool(manager)
ctx := context.Background()
@@ -42,7 +42,7 @@ func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
func TestSpawnTool_Execute_ValidTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnTool(manager)
ctx := context.Background()
+14 -30
View File
@@ -6,7 +6,6 @@ import (
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -27,7 +26,6 @@ type SubagentManager struct {
mu sync.RWMutex
provider providers.LLMProvider
defaultModel string
bus *bus.MessageBus
workspace string
tools *ToolRegistry
maxIterations int
@@ -41,13 +39,11 @@ type SubagentManager struct {
func NewSubagentManager(
provider providers.LLMProvider,
defaultModel, workspace string,
bus *bus.MessageBus,
) *SubagentManager {
return &SubagentManager{
tasks: make(map[string]*SubagentTask),
provider: provider,
defaultModel: defaultModel,
bus: bus,
workspace: workspace,
tools: NewToolRegistry(),
maxIterations: 10,
@@ -214,20 +210,6 @@ After completing the task, provide a clear summary of what was done.`
Async: false,
}
}
// Send announce message back to main agent
if sm.bus != nil {
announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result)
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
sm.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("subagent:%s", task.ID),
// Format: "original_channel:original_chat_id" for routing back
ChatID: fmt.Sprintf("%s:%s", task.OriginChannel, task.OriginChatID),
Content: announceContent,
})
}
}
func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
@@ -252,16 +234,12 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
manager *SubagentManager
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
manager: manager,
}
}
@@ -290,11 +268,6 @@ func (t *SubagentTool) Parameters() map[string]any {
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
task, ok := args["task"].(string)
if !ok {
@@ -341,13 +314,24 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
}
}
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSubagentTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
}
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: llmOptions,
}, messages, t.originChannel, t.originChatID)
}, messages, channel, chatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
+12 -35
View File
@@ -5,7 +5,6 @@ import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -47,12 +46,11 @@ func (m *MockLLMProvider) GetContextWindow() int {
func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.SetLLMOptions(2048, 0.6)
tool := NewSubagentTool(manager)
tool.SetContext("cli", "direct")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "cli", "direct")
args := map[string]any{"task": "Do something"}
result := tool.Execute(ctx, args)
@@ -74,7 +72,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
// TestSubagentTool_Name verifies tool name
func TestSubagentTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
if tool.Name() != "subagent" {
@@ -85,7 +83,7 @@ func TestSubagentTool_Name(t *testing.T) {
// TestSubagentTool_Description verifies tool description
func TestSubagentTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
desc := tool.Description()
@@ -100,7 +98,7 @@ func TestSubagentTool_Description(t *testing.T) {
// TestSubagentTool_Parameters verifies tool parameters schema
func TestSubagentTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
params := tool.Parameters()
@@ -147,28 +145,13 @@ func TestSubagentTool_Parameters(t *testing.T) {
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
args := map[string]any{
"task": "Write a haiku about coding",
"label": "haiku-task",
@@ -219,8 +202,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
// TestSubagentTool_Execute_NoLabel tests execution without label
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
ctx := context.Background()
@@ -243,7 +225,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
ctx := context.Background()
@@ -293,16 +275,12 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) {
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
ctx := WithToolContext(context.Background(), channel, chatID)
args := map[string]any{
"task": "Test context passing",
}
@@ -322,8 +300,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a mock provider that returns very long content
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
ctx := context.Background()
+71 -1
View File
@@ -395,6 +395,68 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
type SearXNGSearchProvider struct {
baseURL string
}
func (p *SearXNGSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("%s/search?q=%s&format=json&categories=general",
strings.TrimSuffix(p.baseURL, "/"),
url.QueryEscape(query))
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("SearXNG returned status %d", resp.StatusCode)
}
var result struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
Engine string `json:"engine"`
Score float64 `json:"score"`
} `json:"results"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(result.Results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
// Limit results to requested count
if len(result.Results) > count {
result.Results = result.Results[:count]
}
// Format results in standard PicoClaw format
var b strings.Builder
b.WriteString(fmt.Sprintf("Results for: %s (via SearXNG)\n", query))
for i, r := range result.Results {
b.WriteString(fmt.Sprintf("%d. %s\n", i+1, r.Title))
b.WriteString(fmt.Sprintf(" %s\n", r.URL))
if r.Content != "" {
b.WriteString(fmt.Sprintf(" %s\n", r.Content))
}
}
return b.String(), nil
}
type GLMSearchProvider struct {
apiKey string
baseURL string
@@ -495,6 +557,9 @@ type WebSearchToolOptions struct {
PerplexityAPIKey string
PerplexityMaxResults int
PerplexityEnabled bool
SearXNGBaseURL string
SearXNGMaxResults int
SearXNGEnabled bool
GLMSearchAPIKey string
GLMSearchBaseURL string
GLMSearchEngine string
@@ -507,7 +572,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
@@ -526,6 +591,11 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.SearXNGEnabled && opts.SearXNGBaseURL != "" {
provider = &SearXNGSearchProvider{baseURL: opts.SearXNGBaseURL}
if opts.SearXNGMaxResults > 0 {
maxResults = opts.SearXNGMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {